/* * The MIT License (MIT) * * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace { using fp8::fp8e4m3fnuz; std::unordered_set get_quantizable_op_names() { static std::unordered_set s = {"convolution", "dot"}; return s; } struct match_fp8ocp_convert_to_fp8fnuz { auto matcher() const { auto dq1 = match::arg(0)( skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1"))); auto dq2 = match::arg(1)( skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } static auto bit_cast_and_handle_specials(module& m, const instruction_ref dq, const instruction_ref x, const instruction_ref bits_0x80_lit, const instruction_ref bits_0x7f_lit, const instruction_ref bits_0xff_lit, const instruction_ref bits_0x00_lit) { auto x_lens = x->get_shape().lens(); auto cast_input = m.insert_instruction( dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); auto mb_bits_0x80_lit = m.insert_instruction( dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit); auto mb_bits_0x7f_lit = m.insert_instruction( dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit); auto mb_bits_0xff_lit = m.insert_instruction( dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit); auto mb_zero_lit = m.insert_instruction( dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit); // negative zero in fp8e4m3fn to zero in fp8e4m3fnuz // a == 0x80 ? 0x0 : a auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit); auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input); // positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz // (a == 0x7f or a == 0xff) ? 0x80 : a auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit); auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit); auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff); ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret); return ret; } // Add the same broadcast instructions after adjusted scales or // adjusted zero points from after the originals. Similar to // propagate_quantized_ins in simplify_qdq. static auto propagate_broadcasts(module& m, const instruction_ref adj, const instruction_ref ori, const instruction_ref start, const instruction_ref insert_pt) { auto prev_ins = start; std::vector ins_between; // matcher skips continguous, multi/broadcasts and transposes, collect all those // instructions while(prev_ins != ori) { ins_between.push_back(prev_ins); prev_ins = prev_ins->inputs().front(); } auto ret = adj; for(auto ins : reverse_iterator_for(ins_between)) { ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret}); } return ret; } static auto cast_to_fnuz(module& m, const instruction_ref dq, const instruction_ref input, const instruction_ref dq_scale, const instruction_ref dq_zp) { auto x = input; std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80); std::vector bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())}; auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f); std::vector bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())}; auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff); std::vector bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())}; auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00); x = bit_cast_and_handle_specials( m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); auto adj_dq_zp = bit_cast_and_handle_specials( m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); // adj_scale = 2 * scale auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}}); two_lit = m.insert_instruction( dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq); adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq); m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp); } auto apply(module& m, const match::matcher_result& r) const { auto dq1 = r.instructions["dq1"]; auto dq2 = r.instructions["dq2"]; auto scale1 = r.instructions["scale1"]; auto scale2 = r.instructions["scale2"]; auto zp1 = r.instructions["zp1"]; auto zp2 = r.instructions["zp2"]; std::set supported_types = {migraphx::shape::fp8e4m3fn_type}; if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or not contains(supported_types, dq2->inputs().front()->get_shape().type())) return; cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1); cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2); } }; } // namespace void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const { module_ref mm = &mpm.get_module(); match::find_matches(*mm, match_fp8ocp_convert_to_fp8fnuz{}); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx