diff --git a/docker/rocm/migraphx/migraphx_py.cpp b/docker/rocm/migraphx/migraphx_py.cpp index 894c9d186..ee830bb47 100644 --- a/docker/rocm/migraphx/migraphx_py.cpp +++ b/docker/rocm/migraphx/migraphx_py.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,10 +25,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -40,12 +42,15 @@ #include #include #include - +#include +#include +#include +#include #ifdef HAVE_GPU #include #endif -using half = half_float::half; +using half = migraphx::half; namespace py = pybind11; #ifdef __clang__ @@ -144,6 +149,61 @@ struct npy_format_descriptor static constexpr auto name() { return _("half"); } }; +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for fp8 + return "z"; + } + static constexpr auto name() { return _("fp8e4m3fnuz"); } +}; + +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for fp8 + return "z"; + } + static constexpr auto name() { return _("fp8e5m2fnuz"); } +}; + +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for fp8 + return "z"; + } + static constexpr auto name() { return _("fp8e4m3fn"); } +}; + +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for fp8 + return "z"; + } + static constexpr auto name() { return _("fp8e5m2"); } +}; + +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for bf16 + return "z"; + } + static constexpr auto name() { return _("bf16"); } +}; + } // namespace detail } // namespace pybind11 @@ -201,10 +261,52 @@ py::buffer_info to_buffer_info(T& x) return b; } +py::object to_py_object(const migraphx::value& val) +{ + py::object result; + + val.visit_value([&](const auto& x) { + if constexpr(std::is_same, std::vector>{}) + { + if(val.is_object()) + { + py::dict py_dict; + for(const auto& item : x) + { + py_dict[py::str(item.get_key())] = to_py_object(item.without_key()); + } + result = py_dict; + } + else + { + py::list py_list; + for(const auto& item : x) + { + py_list.append(to_py_object(item)); + } + result = py_list; + } + } + else + { + result = py::cast(x); + } + }); + + return result; +} + migraphx::shape to_shape(const py::buffer_info& info) { migraphx::shape::type_t t; std::size_t n = 0; + // Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum + if(info.format == "z") + { + MIGRAPHX_THROW( + "MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using " + "migraphx.generate_argument with migraphx.add_literal"); + } visit_types([&](auto as) { if(info.format == py::format_descriptor::format() or (info.format == "l" and py::format_descriptor::format() == "q") or @@ -315,7 +417,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::class_(m, "instruction_ref") .def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); }) - .def("op", [](migraphx::instruction_ref i) { return i->get_operator(); }); + .def("op", [](migraphx::instruction_ref i) { return i->get_operator(); }) + .def("inputs", [](migraphx::instruction_ref i) { return i->inputs(); }) + .def("name", [](migraphx::instruction_ref i) { return i->name(); }) + .def(py::hash(py::self)) + .def(py::self == py::self) + .def(py::self != py::self); py::class_>(m, "module") .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; }) @@ -330,6 +437,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("op"), py::arg("args"), py::arg("mod_args") = std::vector{}) + .def( + "add_literal", + [](migraphx::module& mm, migraphx::argument a) { + return mm.add_literal(a.get_shape(), a.data()); + }, + py::arg("data")) .def( "add_literal", [](migraphx::module& mm, py::buffer data) { @@ -351,7 +464,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) return mm.add_return(args); }, py::arg("args")) - .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); }); + .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); }) + .def( + "__iter__", + [](const migraphx::module& mm) { + auto r = migraphx::iterator_for(mm); + return py::make_iterator(r.begin(), r.end()); + }, + py::keep_alive<0, 1>()); py::class_(m, "program") .def(py::init([]() { return migraphx::program(); })) @@ -410,6 +530,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) migraphx::any_ptr(reinterpret_cast(stream), stream_name), true}; return p.eval(pm, exec_env); }) + .def("to_py", + [](const migraphx::program& p) { + std::stringstream ss; + p.print_py(ss); + return ss.str(); + }) .def("sort", &migraphx::program::sort) .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("__eq__", std::equal_to{}) @@ -425,7 +551,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) } return migraphx::make_op(name, v); })) - .def("name", &migraphx::operation::name); + .def("name", &migraphx::operation::name) + .def("values", [](const migraphx::operation& operation) -> py::object { + return to_py_object(operation.to_value()); + }); py::enum_(op, "pooling_mode") .value("average", migraphx::op::pooling_mode::average) @@ -472,7 +601,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) map_dyn_input_dims, bool skip_unknown_operators, bool print_program_on_error, - int64_t max_loop_iterations) { + int64_t max_loop_iterations, + int64_t limit_max_iterations) { migraphx::onnx_options options; options.default_dim_value = default_dim_value; options.default_dyn_dim_value = default_dyn_dim_value; @@ -481,6 +611,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) options.skip_unknown_operators = skip_unknown_operators; options.print_program_on_error = print_program_on_error; options.max_loop_iterations = max_loop_iterations; + options.limit_max_iterations = limit_max_iterations; return migraphx::parse_onnx(filename, options); }, "Parse onnx file", @@ -492,7 +623,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) std::unordered_map>(), py::arg("skip_unknown_operators") = false, py::arg("print_program_on_error") = false, - py::arg("max_loop_iterations") = 10); + py::arg("max_loop_iterations") = 10, + py::arg("limit_max_iterations") = std::numeric_limits::max()); m.def( "parse_onnx_buffer", @@ -503,7 +635,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) std::unordered_map> map_dyn_input_dims, bool skip_unknown_operators, - bool print_program_on_error) { + bool print_program_on_error, + const std::string& external_data_path) { migraphx::onnx_options options; options.default_dim_value = default_dim_value; options.default_dyn_dim_value = default_dyn_dim_value; @@ -511,6 +644,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) options.map_dyn_input_dims = map_dyn_input_dims; options.skip_unknown_operators = skip_unknown_operators; options.print_program_on_error = print_program_on_error; + options.external_data_path = external_data_path; return migraphx::parse_onnx_buffer(onnx_buffer, options); }, "Parse onnx file", @@ -521,7 +655,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("map_dyn_input_dims") = std::unordered_map>(), py::arg("skip_unknown_operators") = false, - py::arg("print_program_on_error") = false); + py::arg("print_program_on_error") = false, + py::arg("external_data_path") = ""); m.def( "load", @@ -554,7 +689,15 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) a.fill(values.begin(), values.end()); return a; }); - m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); + + m.def( + "generate_argument", + [](const migraphx::shape& s, unsigned long seed) { + return migraphx::generate_argument(s, seed); + }, + py::arg("s"), + py::arg("seed") = 0); + m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value")); m.def("quantize_fp16", &migraphx::quantize_fp16, @@ -565,7 +708,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("prog"), py::arg("t"), py::arg("calibration") = std::vector{}, - py::arg("ins_names") = std::vector{"dot", "convolution"}); + py::arg("ins_names") = std::unordered_set{"dot", "convolution"}); + m.def("quantize_fp8", + &migraphx::quantize_fp8, + py::arg("prog"), + py::arg("t"), + py::arg("calibration") = std::vector{}); + m.def( + "autocast_fp8", + [](migraphx::program& prog) { + migraphx::run_passes(*prog.get_main_module(), {migraphx::autocast_fp8_pass{}}); + }, + "Auto-convert FP8 parameters and return values to Float for MIGraphX Program", + py::arg("prog")); + m.def("quantize_bf16", + &migraphx::quantize_bf16, + py::arg("prog"), + py::arg("ins_names") = std::vector{"all"}); #ifdef HAVE_GPU m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); @@ -577,6 +736,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) #ifdef VERSION_INFO m.attr("__version__") = VERSION_INFO; #else - m.attr("__version__") = "dev"; + auto version_string = std::to_string(MIGRAPHX_VERSION_MAJOR) + "." + + std::to_string(MIGRAPHX_VERSION_MINOR) + "." + + std::to_string(MIGRAPHX_VERSION_PATCH) + ".dev"; + + std::string tweak(MIGRAPHX_VERSION_TWEAK); + if(not tweak.empty()) + version_string += "+" + tweak; + + m.attr("__version__") = version_string; #endif }