mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 17:14:26 +03:00
Update migraphx_py.cpp
This commit is contained in:
parent
c031a9185f
commit
b0d9db1f4d
@ -1,7 +1,7 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
@ -25,12 +25,10 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/quantization.hpp>
|
||||
#include <migraphx/autocast_fp8.hpp>
|
||||
#include <migraphx/generate.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/ref/target.hpp>
|
||||
@ -42,15 +40,12 @@
|
||||
#include <migraphx/json.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/op/common.hpp>
|
||||
#include <migraphx/float8.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/version.h>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
|
||||
#ifdef HAVE_GPU
|
||||
#include <migraphx/gpu/hip.hpp>
|
||||
#endif
|
||||
|
||||
using half = migraphx::half;
|
||||
using half = half_float::half;
|
||||
namespace py = pybind11;
|
||||
|
||||
#ifdef __clang__
|
||||
@ -149,61 +144,6 @@ struct npy_format_descriptor<half>
|
||||
static constexpr auto name() { return _("half"); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
|
||||
{
|
||||
static std::string format()
|
||||
{
|
||||
// TODO: no standard format in numpy for fp8
|
||||
return "z";
|
||||
}
|
||||
static constexpr auto name() { return _("fp8e4m3fnuz"); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct npy_format_descriptor<migraphx::fp8::fp8e5m2fnuz>
|
||||
{
|
||||
static std::string format()
|
||||
{
|
||||
// TODO: no standard format in numpy for fp8
|
||||
return "z";
|
||||
}
|
||||
static constexpr auto name() { return _("fp8e5m2fnuz"); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fn>
|
||||
{
|
||||
static std::string format()
|
||||
{
|
||||
// TODO: no standard format in numpy for fp8
|
||||
return "z";
|
||||
}
|
||||
static constexpr auto name() { return _("fp8e4m3fn"); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct npy_format_descriptor<migraphx::fp8::fp8e5m2>
|
||||
{
|
||||
static std::string format()
|
||||
{
|
||||
// TODO: no standard format in numpy for fp8
|
||||
return "z";
|
||||
}
|
||||
static constexpr auto name() { return _("fp8e5m2"); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct npy_format_descriptor<migraphx::bf16>
|
||||
{
|
||||
static std::string format()
|
||||
{
|
||||
// TODO: no standard format in numpy for bf16
|
||||
return "z";
|
||||
}
|
||||
static constexpr auto name() { return _("bf16"); }
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
@ -261,52 +201,10 @@ py::buffer_info to_buffer_info(T& x)
|
||||
return b;
|
||||
}
|
||||
|
||||
py::object to_py_object(const migraphx::value& val)
|
||||
{
|
||||
py::object result;
|
||||
|
||||
val.visit_value([&](const auto& x) {
|
||||
if constexpr(std::is_same<std::decay_t<decltype(x)>, std::vector<migraphx::value>>{})
|
||||
{
|
||||
if(val.is_object())
|
||||
{
|
||||
py::dict py_dict;
|
||||
for(const auto& item : x)
|
||||
{
|
||||
py_dict[py::str(item.get_key())] = to_py_object(item.without_key());
|
||||
}
|
||||
result = py_dict;
|
||||
}
|
||||
else
|
||||
{
|
||||
py::list py_list;
|
||||
for(const auto& item : x)
|
||||
{
|
||||
py_list.append(to_py_object(item));
|
||||
}
|
||||
result = py_list;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
result = py::cast(x);
|
||||
}
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
migraphx::shape to_shape(const py::buffer_info& info)
|
||||
{
|
||||
migraphx::shape::type_t t;
|
||||
std::size_t n = 0;
|
||||
// Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum
|
||||
if(info.format == "z")
|
||||
{
|
||||
MIGRAPHX_THROW(
|
||||
"MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using "
|
||||
"migraphx.generate_argument with migraphx.add_literal");
|
||||
}
|
||||
visit_types([&](auto as) {
|
||||
if(info.format == py::format_descriptor<decltype(as())>::format() or
|
||||
(info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or
|
||||
@ -417,12 +315,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
|
||||
py::class_<migraphx::instruction_ref>(m, "instruction_ref")
|
||||
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
|
||||
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); })
|
||||
.def("inputs", [](migraphx::instruction_ref i) { return i->inputs(); })
|
||||
.def("name", [](migraphx::instruction_ref i) { return i->name(); })
|
||||
.def(py::hash(py::self))
|
||||
.def(py::self == py::self)
|
||||
.def(py::self != py::self);
|
||||
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
|
||||
|
||||
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
|
||||
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
|
||||
@ -437,12 +330,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
py::arg("op"),
|
||||
py::arg("args"),
|
||||
py::arg("mod_args") = std::vector<migraphx::module*>{})
|
||||
.def(
|
||||
"add_literal",
|
||||
[](migraphx::module& mm, migraphx::argument a) {
|
||||
return mm.add_literal(a.get_shape(), a.data());
|
||||
},
|
||||
py::arg("data"))
|
||||
.def(
|
||||
"add_literal",
|
||||
[](migraphx::module& mm, py::buffer data) {
|
||||
@ -464,14 +351,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
return mm.add_return(args);
|
||||
},
|
||||
py::arg("args"))
|
||||
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); })
|
||||
.def(
|
||||
"__iter__",
|
||||
[](const migraphx::module& mm) {
|
||||
auto r = migraphx::iterator_for(mm);
|
||||
return py::make_iterator(r.begin(), r.end());
|
||||
},
|
||||
py::keep_alive<0, 1>());
|
||||
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
|
||||
|
||||
py::class_<migraphx::program>(m, "program")
|
||||
.def(py::init([]() { return migraphx::program(); }))
|
||||
@ -530,12 +410,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
|
||||
return p.eval(pm, exec_env);
|
||||
})
|
||||
.def("to_py",
|
||||
[](const migraphx::program& p) {
|
||||
std::stringstream ss;
|
||||
p.print_py(ss);
|
||||
return ss.str();
|
||||
})
|
||||
.def("sort", &migraphx::program::sort)
|
||||
.def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
|
||||
.def("__eq__", std::equal_to<migraphx::program>{})
|
||||
@ -551,10 +425,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
}
|
||||
return migraphx::make_op(name, v);
|
||||
}))
|
||||
.def("name", &migraphx::operation::name)
|
||||
.def("values", [](const migraphx::operation& operation) -> py::object {
|
||||
return to_py_object(operation.to_value());
|
||||
});
|
||||
.def("name", &migraphx::operation::name);
|
||||
|
||||
py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
|
||||
.value("average", migraphx::op::pooling_mode::average)
|
||||
@ -601,8 +472,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
map_dyn_input_dims,
|
||||
bool skip_unknown_operators,
|
||||
bool print_program_on_error,
|
||||
int64_t max_loop_iterations,
|
||||
int64_t limit_max_iterations) {
|
||||
int64_t max_loop_iterations) {
|
||||
migraphx::onnx_options options;
|
||||
options.default_dim_value = default_dim_value;
|
||||
options.default_dyn_dim_value = default_dyn_dim_value;
|
||||
@ -611,7 +481,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
options.skip_unknown_operators = skip_unknown_operators;
|
||||
options.print_program_on_error = print_program_on_error;
|
||||
options.max_loop_iterations = max_loop_iterations;
|
||||
options.limit_max_iterations = limit_max_iterations;
|
||||
return migraphx::parse_onnx(filename, options);
|
||||
},
|
||||
"Parse onnx file",
|
||||
@ -623,8 +492,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
|
||||
py::arg("skip_unknown_operators") = false,
|
||||
py::arg("print_program_on_error") = false,
|
||||
py::arg("max_loop_iterations") = 10,
|
||||
py::arg("limit_max_iterations") = std::numeric_limits<uint16_t>::max());
|
||||
py::arg("max_loop_iterations") = 10);
|
||||
|
||||
m.def(
|
||||
"parse_onnx_buffer",
|
||||
@ -635,8 +503,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
|
||||
map_dyn_input_dims,
|
||||
bool skip_unknown_operators,
|
||||
bool print_program_on_error,
|
||||
const std::string& external_data_path) {
|
||||
bool print_program_on_error) {
|
||||
migraphx::onnx_options options;
|
||||
options.default_dim_value = default_dim_value;
|
||||
options.default_dyn_dim_value = default_dyn_dim_value;
|
||||
@ -644,7 +511,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
options.map_dyn_input_dims = map_dyn_input_dims;
|
||||
options.skip_unknown_operators = skip_unknown_operators;
|
||||
options.print_program_on_error = print_program_on_error;
|
||||
options.external_data_path = external_data_path;
|
||||
return migraphx::parse_onnx_buffer(onnx_buffer, options);
|
||||
},
|
||||
"Parse onnx file",
|
||||
@ -655,8 +521,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
py::arg("map_dyn_input_dims") =
|
||||
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
|
||||
py::arg("skip_unknown_operators") = false,
|
||||
py::arg("print_program_on_error") = false,
|
||||
py::arg("external_data_path") = "");
|
||||
py::arg("print_program_on_error") = false);
|
||||
|
||||
m.def(
|
||||
"load",
|
||||
@ -689,15 +554,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
a.fill(values.begin(), values.end());
|
||||
return a;
|
||||
});
|
||||
|
||||
m.def(
|
||||
"generate_argument",
|
||||
[](const migraphx::shape& s, unsigned long seed) {
|
||||
return migraphx::generate_argument(s, seed);
|
||||
},
|
||||
py::arg("s"),
|
||||
py::arg("seed") = 0);
|
||||
|
||||
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
|
||||
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
|
||||
m.def("quantize_fp16",
|
||||
&migraphx::quantize_fp16,
|
||||
@ -708,23 +565,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
py::arg("prog"),
|
||||
py::arg("t"),
|
||||
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
|
||||
py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"});
|
||||
m.def("quantize_fp8",
|
||||
&migraphx::quantize_fp8,
|
||||
py::arg("prog"),
|
||||
py::arg("t"),
|
||||
py::arg("calibration") = std::vector<migraphx::parameter_map>{});
|
||||
m.def(
|
||||
"autocast_fp8",
|
||||
[](migraphx::program& prog) {
|
||||
migraphx::run_passes(*prog.get_main_module(), {migraphx::autocast_fp8_pass{}});
|
||||
},
|
||||
"Auto-convert FP8 parameters and return values to Float for MIGraphX Program",
|
||||
py::arg("prog"));
|
||||
m.def("quantize_bf16",
|
||||
&migraphx::quantize_bf16,
|
||||
py::arg("prog"),
|
||||
py::arg("ins_names") = std::vector<std::string>{"all"});
|
||||
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
|
||||
|
||||
#ifdef HAVE_GPU
|
||||
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
|
||||
@ -736,14 +577,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
#else
|
||||
auto version_string = std::to_string(MIGRAPHX_VERSION_MAJOR) + "." +
|
||||
std::to_string(MIGRAPHX_VERSION_MINOR) + "." +
|
||||
std::to_string(MIGRAPHX_VERSION_PATCH) + ".dev";
|
||||
|
||||
std::string tweak(MIGRAPHX_VERSION_TWEAK);
|
||||
if(not tweak.empty())
|
||||
version_string += "+" + tweak;
|
||||
|
||||
m.attr("__version__") = version_string;
|
||||
m.attr("__version__") = "dev";
|
||||
#endif
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user