mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 17:14:26 +03:00
upload
This commit is contained in:
parent
58e9831aef
commit
7eefb89bf6
46
docker/rocm/migraphx/api/CMakeLists.txt
Normal file
46
docker/rocm/migraphx/api/CMakeLists.txt
Normal file
@ -0,0 +1,46 @@
|
||||
#####################################################################################
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#####################################################################################
|
||||
|
||||
add_library(migraphx_c
|
||||
api.cpp
|
||||
)
|
||||
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
|
||||
migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
|
||||
|
||||
# migraphx_c is stable API interface library. SO version of this should be
|
||||
# bumped when binary compatibility is broken.
|
||||
rocm_set_soversion(migraphx_c 3.0)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
target_compile_definitions(migraphx_c PRIVATE MIGRAPHX_BUILD_TESTING)
|
||||
endif()
|
||||
|
||||
rocm_clang_tidy_check(migraphx_c)
|
||||
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx)
|
||||
|
||||
rocm_install_targets(
|
||||
TARGETS migraphx_c
|
||||
INCLUDE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
)
|
||||
2442
docker/rocm/migraphx/api/api.cpp
Normal file
2442
docker/rocm/migraphx/api/api.cpp
Normal file
File diff suppressed because it is too large
Load Diff
684
docker/rocm/migraphx/api/include/migraphx/migraphx.h
Normal file
684
docker/rocm/migraphx/api/include/migraphx/migraphx.h
Normal file
@ -0,0 +1,684 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H
|
||||
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <migraphx/api/export.h>
|
||||
|
||||
// Add new types here
|
||||
// clang-format off
|
||||
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
|
||||
m(bool_type, bool) \
|
||||
m(half_type, half) \
|
||||
m(float_type, float) \
|
||||
m(double_type, double) \
|
||||
m(uint8_type, uint8_t) \
|
||||
m(int8_type, int8_t) \
|
||||
m(uint16_type, uint16_t) \
|
||||
m(int16_type, int16_t) \
|
||||
m(int32_type, int32_t) \
|
||||
m(int64_type, int64_t) \
|
||||
m(uint32_type, uint32_t) \
|
||||
m(uint64_type, uint64_t) \
|
||||
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
|
||||
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
|
||||
m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \
|
||||
m(bf16_type, bf16) \
|
||||
m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz)
|
||||
// clang-format on
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// return code, more to be added later
|
||||
typedef enum
|
||||
{
|
||||
migraphx_status_success = 0,
|
||||
migraphx_status_bad_param = 1,
|
||||
migraphx_status_unknown_target = 3,
|
||||
migraphx_status_unknown_error = 4,
|
||||
|
||||
} migraphx_status;
|
||||
|
||||
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
|
||||
/// An enum to represent the different data type inputs
|
||||
typedef enum
|
||||
{
|
||||
migraphx_shape_tuple_type,
|
||||
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
|
||||
} migraphx_shape_datatype_t;
|
||||
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
|
||||
|
||||
typedef struct migraphx_optimals* migraphx_optimals_t;
|
||||
typedef const struct migraphx_optimals* const_migraphx_optimals_t;
|
||||
|
||||
typedef struct migraphx_dynamic_dimension* migraphx_dynamic_dimension_t;
|
||||
typedef const struct migraphx_dynamic_dimension* const_migraphx_dynamic_dimension_t;
|
||||
|
||||
typedef struct migraphx_dynamic_dimensions* migraphx_dynamic_dimensions_t;
|
||||
typedef const struct migraphx_dynamic_dimensions* const_migraphx_dynamic_dimensions_t;
|
||||
|
||||
typedef struct migraphx_shape* migraphx_shape_t;
|
||||
typedef const struct migraphx_shape* const_migraphx_shape_t;
|
||||
|
||||
typedef struct migraphx_argument* migraphx_argument_t;
|
||||
typedef const struct migraphx_argument* const_migraphx_argument_t;
|
||||
|
||||
typedef struct migraphx_target* migraphx_target_t;
|
||||
typedef const struct migraphx_target* const_migraphx_target_t;
|
||||
|
||||
typedef struct migraphx_program_parameter_shapes* migraphx_program_parameter_shapes_t;
|
||||
typedef const struct migraphx_program_parameter_shapes* const_migraphx_program_parameter_shapes_t;
|
||||
|
||||
typedef struct migraphx_program_parameters* migraphx_program_parameters_t;
|
||||
typedef const struct migraphx_program_parameters* const_migraphx_program_parameters_t;
|
||||
|
||||
typedef struct migraphx_arguments* migraphx_arguments_t;
|
||||
typedef const struct migraphx_arguments* const_migraphx_arguments_t;
|
||||
|
||||
typedef struct migraphx_shapes* migraphx_shapes_t;
|
||||
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
|
||||
|
||||
typedef struct migraphx_instruction* migraphx_instruction_t;
|
||||
typedef const struct migraphx_instruction* const_migraphx_instruction_t;
|
||||
|
||||
typedef struct migraphx_instructions* migraphx_instructions_t;
|
||||
typedef const struct migraphx_instructions* const_migraphx_instructions_t;
|
||||
|
||||
typedef struct migraphx_modules* migraphx_modules_t;
|
||||
typedef const struct migraphx_modules* const_migraphx_modules_t;
|
||||
|
||||
typedef struct migraphx_module* migraphx_module_t;
|
||||
typedef const struct migraphx_module* const_migraphx_module_t;
|
||||
|
||||
typedef struct migraphx_program* migraphx_program_t;
|
||||
typedef const struct migraphx_program* const_migraphx_program_t;
|
||||
|
||||
typedef struct migraphx_operation* migraphx_operation_t;
|
||||
typedef const struct migraphx_operation* const_migraphx_operation_t;
|
||||
|
||||
typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
|
||||
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
|
||||
|
||||
typedef struct migraphx_file_options* migraphx_file_options_t;
|
||||
typedef const struct migraphx_file_options* const_migraphx_file_options_t;
|
||||
|
||||
typedef struct migraphx_compile_options* migraphx_compile_options_t;
|
||||
typedef const struct migraphx_compile_options* const_migraphx_compile_options_t;
|
||||
|
||||
typedef struct migraphx_tf_options* migraphx_tf_options_t;
|
||||
typedef const struct migraphx_tf_options* const_migraphx_tf_options_t;
|
||||
|
||||
typedef struct migraphx_quantize_op_names* migraphx_quantize_op_names_t;
|
||||
typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_names_t;
|
||||
|
||||
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
|
||||
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
|
||||
|
||||
typedef struct migraphx_quantize_fp8_options* migraphx_quantize_fp8_options_t;
|
||||
typedef const struct migraphx_quantize_fp8_options* const_migraphx_quantize_fp8_options_t;
|
||||
|
||||
typedef struct migraphx_context* migraphx_context_t;
|
||||
typedef const struct migraphx_context* const_migraphx_context_t;
|
||||
|
||||
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
|
||||
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
|
||||
|
||||
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
|
||||
void* obj,
|
||||
char* exception_msg,
|
||||
size_t exception_msg_size,
|
||||
migraphx_context_t ctx,
|
||||
migraphx_shape_t output,
|
||||
migraphx_arguments_t inputs);
|
||||
|
||||
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
|
||||
void* obj,
|
||||
char* exception_msg,
|
||||
size_t exception_msg_size,
|
||||
migraphx_shapes_t inputs);
|
||||
|
||||
typedef migraphx_status (*migraphx_experimental_custom_op_output_alias)(size_t* out,
|
||||
size_t* out_size,
|
||||
void* obj,
|
||||
char* exception_msg,
|
||||
size_t exception_msg_size,
|
||||
migraphx_shapes_t inputs);
|
||||
|
||||
typedef migraphx_status (*migraphx_experimental_custom_op_runs_on_offload_target)(
|
||||
bool* out, void* obj, char* exception_msg, size_t exception_msg_size);
|
||||
|
||||
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
|
||||
|
||||
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
|
||||
const_migraphx_optimals_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_create(migraphx_optimals_t* optimals,
|
||||
const size_t* ptr,
|
||||
size_t size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_assign_to(
|
||||
migraphx_dynamic_dimension_t output, const_migraphx_dynamic_dimension_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_create_min_max(
|
||||
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
|
||||
size_t min,
|
||||
size_t max,
|
||||
migraphx_optimals_t optimals);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_is_fixed(
|
||||
bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimension_equal(bool* out,
|
||||
const_migraphx_dynamic_dimension_t dynamic_dimension,
|
||||
const_migraphx_dynamic_dimension_t x);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimensions_assign_to(
|
||||
migraphx_dynamic_dimensions_t output, const_migraphx_dynamic_dimensions_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
|
||||
const const_migraphx_dynamic_dimension_t* ptr,
|
||||
size_t size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
|
||||
migraphx_dynamic_dimensions_t dynamic_dimensions,
|
||||
size_t idx);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
|
||||
const_migraphx_shape_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
|
||||
migraphx_shape_datatype_t type,
|
||||
size_t* lengths,
|
||||
size_t lengths_size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
|
||||
migraphx_shape_datatype_t type,
|
||||
size_t* lengths,
|
||||
size_t lengths_size,
|
||||
size_t* strides,
|
||||
size_t strides_size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
|
||||
migraphx_shape_datatype_t type);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
|
||||
migraphx_shape_datatype_t type,
|
||||
migraphx_dynamic_dimensions_t dims);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_lengths(const size_t** out,
|
||||
size_t* out_size,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_strides(const size_t** out,
|
||||
size_t* out_size,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_elements(size_t* out,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_equal(bool* out,
|
||||
const_migraphx_shape_t shape,
|
||||
const_migraphx_shape_t x);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_index(size_t* out,
|
||||
const_migraphx_shape_t shape,
|
||||
size_t i);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
|
||||
const_migraphx_argument_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create(migraphx_argument_t* argument,
|
||||
const_migraphx_shape_t shape,
|
||||
void* buffer);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
|
||||
const_migraphx_argument_t argument);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_buffer(char** out,
|
||||
const_migraphx_argument_t argument);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_equal(bool* out,
|
||||
const_migraphx_argument_t argument,
|
||||
const_migraphx_argument_t x);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_generate(migraphx_argument_t* out,
|
||||
const_migraphx_shape_t s,
|
||||
size_t seed);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_destroy(migraphx_target_t target);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_assign_to(migraphx_target_t output,
|
||||
const_migraphx_target_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_create(migraphx_target_t* target,
|
||||
const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_destroy(
|
||||
migraphx_program_parameter_shapes_t program_parameter_shapes);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_assign_to(
|
||||
migraphx_program_parameter_shapes_t output, const_migraphx_program_parameter_shapes_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_size(
|
||||
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
|
||||
migraphx_program_parameter_shapes_t program_parameter_shapes,
|
||||
const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_names(
|
||||
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameters_assign_to(
|
||||
migraphx_program_parameters_t output, const_migraphx_program_parameters_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters,
|
||||
const char* name,
|
||||
const_migraphx_argument_t argument);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
|
||||
const_migraphx_arguments_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_size(size_t* out,
|
||||
migraphx_arguments_t arguments);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_get(const_migraphx_argument_t* out,
|
||||
migraphx_arguments_t arguments,
|
||||
size_t idx);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
|
||||
const_migraphx_shapes_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_get(const_migraphx_shape_t* out,
|
||||
migraphx_shapes_t shapes,
|
||||
size_t idx);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_instruction_assign_to(migraphx_instruction_t output, const_migraphx_instruction_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_instructions_destroy(migraphx_instructions_t instructions);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_assign_to(
|
||||
migraphx_instructions_t output, const_migraphx_instructions_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_create(
|
||||
migraphx_instructions_t* instructions, const const_migraphx_instruction_t* ptr, size_t size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
|
||||
const_migraphx_modules_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_create(migraphx_modules_t* modules,
|
||||
migraphx_module_t* ptr,
|
||||
size_t size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_print(const_migraphx_module_t module);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
|
||||
migraphx_module_t module,
|
||||
migraphx_operation_t op,
|
||||
migraphx_instructions_t args);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
|
||||
migraphx_module_t module,
|
||||
migraphx_operation_t op,
|
||||
migraphx_instructions_t args,
|
||||
migraphx_modules_t module_refs);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
|
||||
migraphx_module_t module,
|
||||
const_migraphx_shape_t shape,
|
||||
const char* buffer);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
|
||||
migraphx_module_t module,
|
||||
const char* name,
|
||||
const_migraphx_shape_t shape);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
|
||||
migraphx_module_t module,
|
||||
migraphx_instructions_t args);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
|
||||
migraphx_module_t module,
|
||||
const_migraphx_shape_t s);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_destroy(migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_assign_to(migraphx_program_t output,
|
||||
const_migraphx_program_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create(migraphx_program_t* program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
|
||||
migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create_module(migraphx_module_t* out,
|
||||
migraphx_program_t program,
|
||||
const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_compile(migraphx_program_t program,
|
||||
migraphx_target_t target,
|
||||
migraphx_compile_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_parameter_shapes(
|
||||
migraphx_program_parameter_shapes_t* out, migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
|
||||
migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_print(const_migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_sort(migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run(migraphx_arguments_t* out,
|
||||
migraphx_program_t program,
|
||||
migraphx_program_parameters_t params);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
|
||||
migraphx_program_t program,
|
||||
migraphx_program_parameters_t params,
|
||||
void* s,
|
||||
const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_equal(bool* out,
|
||||
const_migraphx_program_t program,
|
||||
const_migraphx_program_t x);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_experimental_get_context(
|
||||
migraphx_context_t* out, const_migraphx_program_t program);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
|
||||
const_migraphx_operation_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
|
||||
const char* name,
|
||||
const char* attributes,
|
||||
...);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_name(char* out,
|
||||
size_t out_size,
|
||||
migraphx_operation_t operation);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_load(migraphx_program_t* out,
|
||||
const char* name,
|
||||
migraphx_file_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_save(migraphx_program_t p,
|
||||
const char* name,
|
||||
migraphx_file_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_assign_to(
|
||||
migraphx_onnx_options_t output, const_migraphx_onnx_options_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_input_parameter_shape(
|
||||
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
|
||||
migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_value(
|
||||
migraphx_onnx_options_t onnx_options, const_migraphx_dynamic_dimension_t dd);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
|
||||
migraphx_onnx_options_t onnx_options, int64_t value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_limit_loop_iterations(
|
||||
migraphx_onnx_options_t onnx_options, int64_t value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_external_data_path(
|
||||
migraphx_onnx_options_t onnx_options, const char* external_data_path);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_file_options_destroy(migraphx_file_options_t file_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_file_options_assign_to(
|
||||
migraphx_file_options_t output, const_migraphx_file_options_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_file_options_create(migraphx_file_options_t* file_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_assign_to(
|
||||
migraphx_compile_options_t output, const_migraphx_compile_options_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_set_exhaustive_tune_flag(
|
||||
migraphx_compile_options_t compile_options, bool value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx(migraphx_program_t* out,
|
||||
const char* name,
|
||||
migraphx_onnx_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
|
||||
const void* data,
|
||||
size_t size,
|
||||
migraphx_onnx_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
|
||||
const_migraphx_tf_options_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
|
||||
bool is_nhwc);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_input_parameter_shape(
|
||||
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_output_names(
|
||||
migraphx_tf_options_t tf_options, const char** names, size_t names_size);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_tf(migraphx_program_t* out,
|
||||
const char* name,
|
||||
migraphx_tf_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_op_names_assign_to(
|
||||
migraphx_quantize_op_names_t output, const_migraphx_quantize_op_names_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, migraphx_quantize_op_names_t name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_assign_to(
|
||||
migraphx_quantize_int8_options_t output, const_migraphx_quantize_int8_options_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_op_name(
|
||||
migraphx_quantize_int8_options_t quantize_int8_options, const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_calibration_data(
|
||||
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
|
||||
migraphx_target_t target,
|
||||
migraphx_quantize_int8_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_fp8_options_destroy(migraphx_quantize_fp8_options_t quantize_fp8_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8_options_assign_to(
|
||||
migraphx_quantize_fp8_options_t output, const_migraphx_quantize_fp8_options_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_quantize_fp8_options_create(migraphx_quantize_fp8_options_t* quantize_fp8_options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8_options_add_calibration_data(
|
||||
migraphx_quantize_fp8_options_t quantize_fp8_options, migraphx_program_parameters_t data);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8(migraphx_program_t prog,
|
||||
migraphx_target_t target,
|
||||
migraphx_quantize_fp8_options_t options);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_context_finish(const_migraphx_context_t context);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_context_get_queue(void** out,
|
||||
migraphx_context_t context);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_assign_to(
|
||||
migraphx_experimental_custom_op_t output, const_migraphx_experimental_custom_op_t input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
|
||||
void* obj,
|
||||
migraphx_experimental_custom_op_copy c,
|
||||
migraphx_experimental_custom_op_delete d,
|
||||
const char* obj_typename,
|
||||
const char* name);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute(
|
||||
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute_shape(
|
||||
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_output_alias(
|
||||
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
|
||||
migraphx_experimental_custom_op_t obj,
|
||||
migraphx_experimental_custom_op_runs_on_offload_target input);
|
||||
|
||||
MIGRAPHX_C_EXPORT migraphx_status
|
||||
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
1589
docker/rocm/migraphx/api/include/migraphx/migraphx.hpp
Normal file
1589
docker/rocm/migraphx/api/include/migraphx/migraphx.hpp
Normal file
File diff suppressed because it is too large
Load Diff
510
docker/rocm/migraphx/api/migraphx.py
Normal file
510
docker/rocm/migraphx/api/migraphx.py
Normal file
@ -0,0 +1,510 @@
|
||||
#####################################################################################
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#####################################################################################
|
||||
import api
|
||||
|
||||
|
||||
def bad_param_error(msg):
|
||||
return 'MIGRAPHX_THROW(migraphx_status_bad_param, "{}")'.format(msg)
|
||||
|
||||
|
||||
api.error_type = 'migraphx_status'
|
||||
api.success_type = 'migraphx_status_success'
|
||||
api.try_wrap = 'migraphx::try_'
|
||||
api.bad_param_error = bad_param_error
|
||||
|
||||
|
||||
@api.cwrap('migraphx::shape::type_t')
|
||||
def shape_type_wrap(p):
|
||||
if p.returns:
|
||||
p.add_param('migraphx_shape_datatype_t *')
|
||||
p.bad_param('${name} == nullptr', 'Null pointer')
|
||||
p.write = ['*${name} = migraphx::to_shape_type(${result})']
|
||||
else:
|
||||
p.add_param('migraphx_shape_datatype_t')
|
||||
p.read = 'migraphx::to_shape_type(${name})'
|
||||
|
||||
|
||||
def auto_handle(*args, **kwargs):
|
||||
def with_handle(f):
|
||||
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
|
||||
*args, **kwargs)(f)
|
||||
|
||||
return with_handle
|
||||
|
||||
|
||||
@api.handle('migraphx_optimals', 'std::set<size_t>')
|
||||
def optimals(h):
|
||||
h.constructor('create',
|
||||
api.params(ptr='const size_t*', size='size_t'),
|
||||
fname='migraphx::make_set<size_t>')
|
||||
|
||||
|
||||
@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
|
||||
def dynamic_dimension(h):
|
||||
h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
|
||||
h.constructor(
|
||||
'create_min_max_optimals',
|
||||
api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
|
||||
h.method('is_fixed', returns='bool', const=True)
|
||||
h.method('equal',
|
||||
api.params(x='const migraphx::shape::dynamic_dimension&'),
|
||||
invoke='migraphx::equal($@)',
|
||||
returns='bool',
|
||||
const=True)
|
||||
|
||||
|
||||
@api.handle('migraphx_dynamic_dimensions',
|
||||
'std::vector<migraphx::shape::dynamic_dimension>')
|
||||
def dynamic_dimensions(h):
|
||||
h.constructor(
|
||||
'create',
|
||||
api.params(ptr='const const_migraphx_dynamic_dimension_t*',
|
||||
size='size_t'),
|
||||
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
|
||||
h.method('size', returns='size_t')
|
||||
h.method('get',
|
||||
api.params(idx='size_t'),
|
||||
fname='at',
|
||||
cpp_name='operator[]',
|
||||
returns='const migraphx::shape::dynamic_dimension&')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def shape(h):
|
||||
h.constructor(
|
||||
'create',
|
||||
api.params(type='migraphx::shape::type_t',
|
||||
lengths='std::vector<size_t>'))
|
||||
h.constructor(
|
||||
'create_with_strides',
|
||||
api.params(type='migraphx::shape::type_t',
|
||||
lengths='std::vector<size_t>',
|
||||
strides='std::vector<size_t>'))
|
||||
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
|
||||
h.constructor(
|
||||
'create_dynamic',
|
||||
api.params(type='migraphx::shape::type_t',
|
||||
dims='std::vector<migraphx::shape::dynamic_dimension>'))
|
||||
h.method('lengths',
|
||||
fname='lens',
|
||||
returns='const std::vector<size_t>&',
|
||||
const=True)
|
||||
h.method('strides', returns='const std::vector<size_t>&', const=True)
|
||||
h.method('dyn_dims',
|
||||
returns='std::vector<migraphx::shape::dynamic_dimension>',
|
||||
const=True)
|
||||
h.method('type', returns='migraphx::shape::type_t', const=True)
|
||||
h.method('elements', returns='size_t', const=True)
|
||||
h.method('bytes', returns='size_t', const=True)
|
||||
h.method('ndim', returns='size_t', const=True)
|
||||
h.method('equal',
|
||||
api.params(x='const migraphx::shape&'),
|
||||
invoke='migraphx::equal($@)',
|
||||
returns='bool',
|
||||
const=True)
|
||||
h.method('standard', returns='bool', const=True)
|
||||
h.method('dynamic', returns='bool', const=True)
|
||||
h.method('index', api.params(i='size_t'), returns='size_t', const=True)
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def argument(h):
|
||||
h.constructor('create',
|
||||
api.params(shape='const migraphx::shape&', buffer='void*'))
|
||||
h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
|
||||
h.method('shape',
|
||||
fname='get_shape',
|
||||
cpp_name='get_shape',
|
||||
returns='const migraphx::shape&',
|
||||
const=True)
|
||||
h.method('buffer',
|
||||
fname='data',
|
||||
cpp_name='data',
|
||||
returns='char*',
|
||||
const=True)
|
||||
h.method('equal',
|
||||
api.params(x='const migraphx::argument&'),
|
||||
invoke='migraphx::equal($@)',
|
||||
returns='bool',
|
||||
const=True)
|
||||
|
||||
|
||||
api.add_function('migraphx_argument_generate',
|
||||
api.params(s='const migraphx::shape&', seed='size_t'),
|
||||
fname='migraphx::generate_argument',
|
||||
returns='migraphx::argument')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def target(h):
|
||||
h.constructor('create',
|
||||
api.params(name='const char*'),
|
||||
fname='migraphx::get_target')
|
||||
|
||||
|
||||
@api.handle('migraphx_program_parameter_shapes',
|
||||
'std::unordered_map<std::string, migraphx::shape>')
|
||||
def program_parameter_shapes(h):
|
||||
h.method('size', returns='size_t')
|
||||
h.method('get',
|
||||
api.params(name='const char*'),
|
||||
fname='at',
|
||||
cpp_name='operator[]',
|
||||
returns='const migraphx::shape&')
|
||||
h.method('names',
|
||||
invoke='migraphx::get_names(${program_parameter_shapes})',
|
||||
returns='std::vector<const char*>')
|
||||
|
||||
|
||||
@api.handle('migraphx_program_parameters',
|
||||
'std::unordered_map<std::string, migraphx::argument>')
|
||||
def program_parameters(h):
|
||||
h.constructor('create')
|
||||
h.method('add',
|
||||
api.params(name='const char*',
|
||||
argument='const migraphx::argument&'),
|
||||
invoke='${program_parameters}[${name}] = ${argument}')
|
||||
|
||||
|
||||
@api.handle('migraphx_arguments', 'std::vector<migraphx::argument>')
|
||||
def arguments(h):
|
||||
h.method('size', returns='size_t')
|
||||
h.method('get',
|
||||
api.params(idx='size_t'),
|
||||
fname='at',
|
||||
cpp_name='operator[]',
|
||||
returns='const migraphx::argument&')
|
||||
|
||||
|
||||
@api.handle('migraphx_shapes', 'std::vector<migraphx::shape>')
|
||||
def shapes(h):
|
||||
h.method('size', returns='size_t')
|
||||
h.method('get',
|
||||
api.params(idx='size_t'),
|
||||
fname='at',
|
||||
cpp_name='operator[]',
|
||||
returns='const migraphx::shape&')
|
||||
|
||||
|
||||
@api.handle('migraphx_instruction', 'migraphx::instruction_ref')
|
||||
def instruction(h):
|
||||
pass
|
||||
|
||||
|
||||
@api.handle('migraphx_instructions', 'std::vector<migraphx::instruction_ref>')
|
||||
def instructions(h):
|
||||
h.constructor(
|
||||
'create',
|
||||
api.params(ptr='const const_migraphx_instruction_t*', size='size_t'),
|
||||
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
|
||||
|
||||
|
||||
@api.handle('migraphx_modules', 'std::vector<migraphx::module*>')
|
||||
def modules(h):
|
||||
h.constructor('create',
|
||||
api.params(ptr='migraphx_module_t*', size='size_t'),
|
||||
fname='migraphx::to_objptr_vector<migraphx::module*>')
|
||||
|
||||
|
||||
@auto_handle(ref=True)
|
||||
def module(h):
|
||||
h.constructor('create', api.params(name='std::string'))
|
||||
h.method('print', invoke='migraphx::print_module($@)', const=True)
|
||||
h.method('add_instruction',
|
||||
api.params(op='migraphx::operation',
|
||||
args='std::vector<migraphx::instruction_ref>'),
|
||||
returns='migraphx::instruction_ref')
|
||||
h.method('add_instruction_with_mod_args',
|
||||
api.params(op='migraphx::operation',
|
||||
args='std::vector<migraphx::instruction_ref>',
|
||||
module_refs='std::vector<migraphx::module*>'),
|
||||
fname='add_instruction',
|
||||
returns='migraphx::instruction_ref')
|
||||
h.method('add_literal',
|
||||
api.params(shape='const migraphx::shape&', buffer='const char*'),
|
||||
returns='migraphx::instruction_ref')
|
||||
h.method('add_parameter',
|
||||
api.params(name='const char*', shape='const migraphx::shape&'),
|
||||
returns='migraphx::instruction_ref')
|
||||
h.method('add_return',
|
||||
api.params(args='std::vector<migraphx::instruction_ref>'),
|
||||
returns='migraphx::instruction_ref')
|
||||
h.method('add_allocation',
|
||||
api.params(s='const migraphx::shape&'),
|
||||
invoke='migraphx::add_allocation($@)',
|
||||
returns='migraphx::instruction_ref')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def program(h):
|
||||
h.constructor('create')
|
||||
h.method('get_main_module', returns='migraphx::module*')
|
||||
h.method('create_module',
|
||||
api.params(name='const char*'),
|
||||
returns='migraphx::module*')
|
||||
h.method(
|
||||
'compile',
|
||||
api.params(target='migraphx::target',
|
||||
options='migraphx::compile_options'))
|
||||
h.method('get_parameter_shapes',
|
||||
returns='std::unordered_map<std::string, migraphx::shape>')
|
||||
h.method('get_output_shapes',
|
||||
invoke='migraphx::get_output_shapes($@)',
|
||||
returns='std::vector<migraphx::shape>')
|
||||
h.method('print', invoke='migraphx::print_program($@)', const=True)
|
||||
h.method('sort')
|
||||
h.method('run',
|
||||
api.params(
|
||||
params='std::unordered_map<std::string, migraphx::argument>'),
|
||||
invoke='migraphx::run($@)',
|
||||
returns='std::vector<migraphx::argument>')
|
||||
h.method('run_async',
|
||||
api.params(
|
||||
params='std::unordered_map<std::string, migraphx::argument>',
|
||||
s='void*',
|
||||
name='const char *'),
|
||||
invoke='migraphx::run_async($@)',
|
||||
returns='std::vector<migraphx::argument>')
|
||||
h.method('equal',
|
||||
api.params(x='const migraphx::program&'),
|
||||
invoke='migraphx::equal($@)',
|
||||
returns='bool',
|
||||
const=True)
|
||||
h.method('experimental_get_context',
|
||||
invoke='migraphx::get_context($@)',
|
||||
const=True,
|
||||
returns='migraphx::context')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def operation(h):
|
||||
h.constructor('create',
|
||||
api.params(name='const char*',
|
||||
attributes='const char*',
|
||||
vlist='...'),
|
||||
fname='migraphx::create_op')
|
||||
h.method('name', returns='std::string')
|
||||
|
||||
|
||||
api.add_function('migraphx_load',
|
||||
api.params(name='const char*',
|
||||
options='migraphx::file_options'),
|
||||
fname='migraphx::load',
|
||||
returns='migraphx::program')
|
||||
|
||||
api.add_function('migraphx_save',
|
||||
api.params(p='migraphx::program&',
|
||||
name='const char*',
|
||||
options='migraphx::file_options'),
|
||||
fname='migraphx::save')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def onnx_options(h):
|
||||
h.constructor('create')
|
||||
h.method(
|
||||
'set_input_parameter_shape',
|
||||
api.params(name='const char*', dims='std::vector<size_t>'),
|
||||
invoke='migraphx::set_input_parameter_shape($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_dyn_input_parameter_shape',
|
||||
api.params(name='const char*',
|
||||
dims='std::vector<migraphx::shape::dynamic_dimension>'),
|
||||
invoke='migraphx::set_dyn_input_parameter_shape($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_default_dim_value',
|
||||
api.params(value='size_t'),
|
||||
invoke='migraphx::set_default_dim_value($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_default_dyn_dim_value',
|
||||
api.params(dd='const migraphx::shape::dynamic_dimension&'),
|
||||
invoke='migraphx::set_default_dyn_dim_value($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_default_loop_iterations',
|
||||
api.params(value='int64_t'),
|
||||
invoke='migraphx::set_default_loop_iterations($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_limit_loop_iterations',
|
||||
api.params(value='int64_t'),
|
||||
invoke='migraphx::set_limit_loop_iterations($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_external_data_path',
|
||||
api.params(external_data_path='const char*'),
|
||||
invoke='migraphx::set_external_data_path($@)',
|
||||
)
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def file_options(h):
|
||||
h.constructor('create')
|
||||
h.method('set_file_format',
|
||||
api.params(format='const char*'),
|
||||
invoke='migraphx::set_file_format($@)')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def compile_options(h):
|
||||
h.constructor('create')
|
||||
h.method('set_offload_copy',
|
||||
api.params(value='bool'),
|
||||
invoke='migraphx::set_offload_copy($@)')
|
||||
h.method('set_fast_math',
|
||||
api.params(value='bool'),
|
||||
invoke='migraphx::set_fast_math($@)')
|
||||
h.method('set_exhaustive_tune_flag',
|
||||
api.params(value='bool'),
|
||||
invoke='migraphx::set_exhaustive_tune_flag($@)')
|
||||
|
||||
|
||||
api.add_function('migraphx_parse_onnx',
|
||||
api.params(name='const char*',
|
||||
options='migraphx::onnx_options'),
|
||||
fname='migraphx::parse_onnx',
|
||||
returns='migraphx::program')
|
||||
|
||||
api.add_function('migraphx_parse_onnx_buffer',
|
||||
api.params(data='const void*',
|
||||
size='size_t',
|
||||
options='migraphx::onnx_options'),
|
||||
fname='migraphx::parse_onnx_buffer',
|
||||
returns='migraphx::program')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def tf_options(h):
|
||||
h.constructor('create')
|
||||
h.method(
|
||||
'set_nhwc',
|
||||
api.params(is_nhwc='bool'),
|
||||
invoke='migraphx::set_nhwc($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_input_parameter_shape',
|
||||
api.params(name='const char*', dims='std::vector<size_t>'),
|
||||
invoke='migraphx::set_input_parameter_shape($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_default_dim_value',
|
||||
api.params(value='size_t'),
|
||||
invoke='migraphx::set_default_dim_value($@)',
|
||||
)
|
||||
h.method(
|
||||
'set_output_names',
|
||||
api.params(names='std::vector<const char*>'),
|
||||
invoke='migraphx::set_output_names($@)',
|
||||
)
|
||||
|
||||
|
||||
api.add_function('migraphx_parse_tf',
|
||||
api.params(name='const char*',
|
||||
options='migraphx::tf_options'),
|
||||
fname='migraphx::parse_tf',
|
||||
returns='migraphx::program')
|
||||
|
||||
|
||||
@api.handle('migraphx_quantize_op_names', 'std::vector<std::string>')
|
||||
def quantize_op_names(h):
|
||||
h.constructor('create')
|
||||
h.method('add', api.params(name='const char*'), fname='push_back')
|
||||
|
||||
|
||||
api.add_function('migraphx_quantize_fp16_with_op_names',
|
||||
api.params(prog='migraphx::program&',
|
||||
name='std::vector<std::string>&'),
|
||||
fname='migraphx::quantize_fp16_with_op_names')
|
||||
|
||||
api.add_function('migraphx_quantize_fp16',
|
||||
api.params(prog='migraphx::program&'),
|
||||
fname='migraphx::quantize_fp16')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def quantize_int8_options(h):
|
||||
h.constructor('create')
|
||||
h.method(
|
||||
'add_op_name',
|
||||
api.params(name='const char*'),
|
||||
invoke='migraphx::add_op_name($@)',
|
||||
)
|
||||
h.method(
|
||||
'add_calibration_data',
|
||||
api.params(data='std::unordered_map<std::string, migraphx::argument>'),
|
||||
invoke='migraphx::add_calibration_data($@)',
|
||||
)
|
||||
|
||||
|
||||
api.add_function('migraphx_quantize_int8',
|
||||
api.params(prog='migraphx::program&',
|
||||
target='migraphx::target',
|
||||
options='migraphx::quantize_int8_options'),
|
||||
fname='migraphx::quantize_int8_wrap')
|
||||
|
||||
|
||||
@auto_handle()
|
||||
def quantize_fp8_options(h):
|
||||
h.constructor('create')
|
||||
h.method(
|
||||
'add_calibration_data',
|
||||
api.params(data='std::unordered_map<std::string, migraphx::argument>'),
|
||||
invoke='migraphx::add_calibration_data($@)',
|
||||
)
|
||||
|
||||
|
||||
api.add_function('migraphx_quantize_fp8',
|
||||
api.params(prog='migraphx::program&',
|
||||
target='migraphx::target',
|
||||
options='migraphx::quantize_fp8_options'),
|
||||
fname='migraphx::quantize_fp8_wrap')
|
||||
|
||||
|
||||
@auto_handle(ref=True)
|
||||
def context(h):
|
||||
h.method('finish', const=True)
|
||||
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
|
||||
|
||||
|
||||
@api.interface('migraphx_experimental_custom_op',
|
||||
'migraphx::experimental_custom_op')
|
||||
def experimental_custom_op(h):
|
||||
h.constructor('create',
|
||||
api.params(obj_typename='const char*', name='const char*'))
|
||||
h.virtual('compute',
|
||||
api.params(ctx='migraphx::context',
|
||||
output='migraphx::shape',
|
||||
inputs='std::vector<migraphx::argument>'),
|
||||
returns='migraphx::argument')
|
||||
h.virtual('compute_shape',
|
||||
api.params(inputs='std::vector<migraphx::shape>'),
|
||||
returns='migraphx::shape')
|
||||
h.virtual('output_alias',
|
||||
api.params(inputs='std::vector<migraphx::shape>'),
|
||||
returns='std::vector<size_t>')
|
||||
h.virtual('runs_on_offload_target', returns='bool')
|
||||
h.method('register', invoke='migraphx::register_custom_op($@)')
|
||||
60
docker/rocm/migraphx/driver/CMakeLists.txt
Normal file
60
docker/rocm/migraphx/driver/CMakeLists.txt
Normal file
@ -0,0 +1,60 @@
|
||||
#####################################################################################
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#####################################################################################
|
||||
|
||||
add_executable(driver
|
||||
main.cpp
|
||||
verify.cpp
|
||||
passes.cpp
|
||||
mlir.cpp
|
||||
models.cpp
|
||||
perf.cpp
|
||||
marker_roctx.cpp
|
||||
)
|
||||
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
|
||||
if(NOT WIN32)
|
||||
# Copy driver for backwards compatibility (Linux only)
|
||||
add_custom_command(
|
||||
TARGET driver
|
||||
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
|
||||
$<TARGET_FILE:driver>
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
|
||||
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
|
||||
)
|
||||
set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
|
||||
endif()
|
||||
rocm_clang_tidy_check(driver)
|
||||
|
||||
file(STRINGS "${CMAKE_SOURCE_DIR}/test/onnx/.onnxrt-commit" String_output)
|
||||
target_compile_definitions(driver PUBLIC MIGRAPHX_ORT_SHA1="${String_output}")
|
||||
|
||||
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
|
||||
|
||||
if(MIGRAPHX_ENABLE_PYTHON)
|
||||
target_link_libraries(driver migraphx_py)
|
||||
target_compile_definitions(driver PRIVATE MIGRAPHX_ENABLE_PYTHON)
|
||||
endif()
|
||||
|
||||
rocm_install_targets(
|
||||
TARGETS driver
|
||||
)
|
||||
748
docker/rocm/migraphx/driver/argument_parser.hpp
Normal file
748
docker/rocm/migraphx/driver/argument_parser.hpp
Normal file
@ -0,0 +1,748 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/requires.hpp>
|
||||
#include <migraphx/type_name.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/filesystem.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/algorithm.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/rank.hpp>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifdef MIGRAPHX_USE_CLANG_TIDY
|
||||
#define MIGRAPHX_DRIVER_STATIC
|
||||
#else
|
||||
#define MIGRAPHX_DRIVER_STATIC static
|
||||
#endif
|
||||
|
||||
template <class T>
|
||||
using bare = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T>
|
||||
auto is_container(int, T&& x) -> decltype(x.insert(x.end(), *x.begin()), std::true_type{});
|
||||
|
||||
template <class T>
|
||||
std::false_type is_container(float, T&&);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class T>
|
||||
struct is_container : decltype(detail::is_container(int(0), std::declval<T>()))
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using is_multi_value =
|
||||
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
|
||||
|
||||
enum class color
|
||||
{
|
||||
reset = 0,
|
||||
bold = 1,
|
||||
underlined = 4,
|
||||
fg_red = 31,
|
||||
fg_green = 32,
|
||||
fg_yellow = 33,
|
||||
fg_blue = 34,
|
||||
fg_default = 39,
|
||||
bg_red = 41,
|
||||
bg_green = 42,
|
||||
bg_yellow = 43,
|
||||
bg_blue = 44,
|
||||
bg_default = 49
|
||||
};
|
||||
inline std::ostream& operator<<(std::ostream& os, const color& c)
|
||||
{
|
||||
#ifndef _WIN32
|
||||
static const bool use_color = isatty(STDOUT_FILENO) != 0;
|
||||
if(use_color)
|
||||
return os << "\033[" << static_cast<std::size_t>(c) << "m";
|
||||
#else
|
||||
(void)c;
|
||||
#endif
|
||||
return os;
|
||||
}
|
||||
|
||||
inline std::string colorize(color c, const std::string& s)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << c << s << color::reset;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct type_name
|
||||
{
|
||||
static const std::string& apply() { return migraphx::get_type_name<T>(); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_name<std::string>
|
||||
{
|
||||
static const std::string& apply()
|
||||
{
|
||||
static const std::string name = "std::string";
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct type_name<std::vector<T>>
|
||||
{
|
||||
static const std::string& apply()
|
||||
{
|
||||
static const std::string name = "std::vector<" + type_name<T>::apply() + ">";
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct value_parser
|
||||
{
|
||||
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{} and not is_multi_value<T>{})>
|
||||
static T apply(const std::string& x)
|
||||
{
|
||||
// handle whitespace in string
|
||||
if constexpr(std::is_same<T, std::string>{})
|
||||
{
|
||||
return x;
|
||||
}
|
||||
else
|
||||
{
|
||||
T result;
|
||||
std::stringstream ss;
|
||||
ss.str(x);
|
||||
ss >> result;
|
||||
if(ss.fail())
|
||||
throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
template <MIGRAPHX_REQUIRES(std::is_enum<T>{} and not is_multi_value<T>{})>
|
||||
static T apply(const std::string& x)
|
||||
{
|
||||
std::ptrdiff_t i;
|
||||
std::stringstream ss;
|
||||
ss.str(x);
|
||||
ss >> i;
|
||||
if(ss.fail())
|
||||
throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
|
||||
return static_cast<T>(i);
|
||||
}
|
||||
|
||||
template <MIGRAPHX_REQUIRES(is_multi_value<T>{} and not std::is_enum<T>{})>
|
||||
static T apply(const std::string& x)
|
||||
{
|
||||
T result;
|
||||
using value_type = typename T::value_type;
|
||||
result.insert(result.end(), value_parser<value_type>::apply(x));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// version for std::optional object
|
||||
template <class T>
|
||||
struct value_parser<std::optional<T>>
|
||||
{
|
||||
static T apply(const std::string& x) { return value_parser<T>::apply(x); }
|
||||
};
|
||||
|
||||
struct argument_parser
|
||||
{
|
||||
struct argument
|
||||
{
|
||||
using action_function =
|
||||
std::function<bool(argument_parser&, const std::vector<std::string>&)>;
|
||||
using validate_function =
|
||||
std::function<void(const argument_parser&, const std::vector<std::string>&)>;
|
||||
std::vector<std::string> flags;
|
||||
action_function action{};
|
||||
std::string type = "";
|
||||
std::string help = "";
|
||||
std::string metavar = "";
|
||||
std::string default_value = "";
|
||||
std::string group = "";
|
||||
unsigned nargs = 1;
|
||||
bool required = false;
|
||||
std::vector<validate_function> validations{};
|
||||
|
||||
std::string usage(const std::string& flag) const
|
||||
{
|
||||
std::stringstream ss;
|
||||
if(flag.empty())
|
||||
{
|
||||
ss << metavar;
|
||||
}
|
||||
else
|
||||
{
|
||||
ss << flag;
|
||||
if(not type.empty())
|
||||
ss << " [" << type << "]";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
std::string usage() const
|
||||
{
|
||||
if(flags.empty())
|
||||
return usage("");
|
||||
return usage(flags.front());
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
|
||||
std::string as_string_value(const T& x)
|
||||
{
|
||||
return to_string_range(x);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x))
|
||||
{
|
||||
return to_string(x);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::string as_string_value(rank<0>, const T&)
|
||||
{
|
||||
throw std::runtime_error("Can't convert to string");
|
||||
}
|
||||
|
||||
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
|
||||
std::string as_string_value(const T& x)
|
||||
{
|
||||
return as_string_value(rank<1>{}, x);
|
||||
}
|
||||
|
||||
template <class T, class... Fs>
|
||||
void operator()(T& x, const std::vector<std::string>& flags, Fs... fs)
|
||||
{
|
||||
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
|
||||
if(params.empty())
|
||||
throw std::runtime_error("Flag with no value.");
|
||||
if(not is_multi_value<T>{} and params.size() > 1)
|
||||
throw std::runtime_error("Too many arguments passed.");
|
||||
x = value_parser<T>::apply(params.back());
|
||||
return false;
|
||||
}});
|
||||
|
||||
argument& arg = arguments.back();
|
||||
arg.type = type_name<T>::apply();
|
||||
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
|
||||
if(not arg.default_value.empty() and arg.nargs > 0)
|
||||
arg.default_value = as_string_value(x);
|
||||
}
|
||||
|
||||
template <class... Fs>
|
||||
void operator()(std::nullptr_t x, std::vector<std::string> flags, Fs... fs)
|
||||
{
|
||||
arguments.push_back({std::move(flags)});
|
||||
|
||||
argument& arg = arguments.back();
|
||||
arg.type = "";
|
||||
arg.nargs = 0;
|
||||
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto nargs(unsigned n = 1)
|
||||
{
|
||||
return [=](auto&&, auto& arg) { arg.nargs = n; };
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto required()
|
||||
{
|
||||
return [=](auto&&, auto& arg) { arg.required = true; };
|
||||
}
|
||||
|
||||
template <class F>
|
||||
MIGRAPHX_DRIVER_STATIC auto write_action(F f)
|
||||
{
|
||||
return [=](auto& x, auto& arg) {
|
||||
arg.action = [&, f](auto& self, const std::vector<std::string>& params) {
|
||||
f(self, x, params);
|
||||
return false;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
MIGRAPHX_DRIVER_STATIC auto do_action(F f)
|
||||
{
|
||||
return [=](auto&, auto& arg) {
|
||||
arg.nargs = 0;
|
||||
arg.action = [&, f](auto& self, const std::vector<std::string>&) {
|
||||
f(self);
|
||||
return true;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto append()
|
||||
{
|
||||
return write_action([](auto&, auto& x, auto& params) {
|
||||
using type = typename bare<decltype(params)>::value_type;
|
||||
std::transform(params.begin(),
|
||||
params.end(),
|
||||
std::inserter(x, x.end()),
|
||||
[](std::string y) { return value_parser<type>::apply(y); });
|
||||
});
|
||||
}
|
||||
|
||||
template <class F>
|
||||
MIGRAPHX_DRIVER_STATIC auto validate(F f)
|
||||
{
|
||||
return [=](const auto& x, auto& arg) {
|
||||
arg.validations.push_back(
|
||||
[&, f](auto& self, const std::vector<std::string>& params) { f(self, x, params); });
|
||||
};
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto file_exist()
|
||||
{
|
||||
return validate([](auto&, auto&, const auto& params) {
|
||||
if(params.empty())
|
||||
throw std::runtime_error("No argument passed.");
|
||||
if(not fs::exists(params.back()))
|
||||
throw std::runtime_error("Path does not exist: " + params.back());
|
||||
});
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto matches(const std::unordered_set<std::string>& names)
|
||||
{
|
||||
return validate([=](auto&, auto&, const auto& params) {
|
||||
auto invalid_param = std::find_if(
|
||||
params.begin(), params.end(), [&](const auto& p) { return names.count(p) == 0; });
|
||||
if(invalid_param != params.end())
|
||||
throw std::runtime_error("Invalid argument: " + *invalid_param +
|
||||
". Valid arguments are {" + to_string_range(names) + "}");
|
||||
});
|
||||
}
|
||||
|
||||
template <class F>
|
||||
argument* find_argument(F f)
|
||||
{
|
||||
auto it = std::find_if(arguments.begin(), arguments.end(), f);
|
||||
if(it == arguments.end())
|
||||
return nullptr;
|
||||
return std::addressof(*it);
|
||||
}
|
||||
template <class F>
|
||||
bool has_argument(F f)
|
||||
{
|
||||
return find_argument(f) != nullptr;
|
||||
}
|
||||
|
||||
template <class F>
|
||||
std::vector<argument*> find_arguments(F f)
|
||||
{
|
||||
std::vector<argument*> result;
|
||||
for(auto& arg : arguments)
|
||||
{
|
||||
if(not f(arg))
|
||||
continue;
|
||||
result.push_back(&arg);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<argument*> get_group_arguments(const std::string& group)
|
||||
{
|
||||
return find_arguments([&](const auto& arg) { return arg.group == group; });
|
||||
}
|
||||
|
||||
std::vector<argument*> get_required_arguments()
|
||||
{
|
||||
return find_arguments([&](const auto& arg) { return arg.required; });
|
||||
}
|
||||
|
||||
template <class SequenceContainer>
|
||||
std::vector<std::string> get_argument_usages(SequenceContainer args)
|
||||
{
|
||||
std::vector<std::string> usage_flags;
|
||||
std::unordered_set<std::string> found_groups;
|
||||
// Remove arguments that belong to a group
|
||||
auto it = std::remove_if(args.begin(), args.end(), [&](const argument* arg) {
|
||||
if(arg->group.empty())
|
||||
return false;
|
||||
found_groups.insert(arg->group);
|
||||
return true;
|
||||
});
|
||||
args.erase(it, args.end());
|
||||
transform(found_groups, std::back_inserter(usage_flags), [&](auto&& group) {
|
||||
std::vector<std::string> either_flags;
|
||||
transform(get_group_arguments(group), std::back_inserter(either_flags), [](auto* arg) {
|
||||
return arg->usage();
|
||||
});
|
||||
return "(" + join_strings(either_flags, "|") + ")";
|
||||
});
|
||||
transform(args, std::back_inserter(usage_flags), [&](auto* arg) { return arg->usage(); });
|
||||
return usage_flags;
|
||||
}
|
||||
|
||||
auto show_help(const std::string& msg = "")
|
||||
{
|
||||
return do_action([=](auto& self) {
|
||||
argument* input_argument =
|
||||
self.find_argument([](const auto& arg) { return arg.flags.empty(); });
|
||||
auto required_usages = get_argument_usages(get_required_arguments());
|
||||
if(required_usages.empty() and input_argument)
|
||||
required_usages.push_back(input_argument->metavar);
|
||||
required_usages.insert(required_usages.begin(), "<options>");
|
||||
print_usage(required_usages);
|
||||
std::cout << std::endl;
|
||||
if(self.find_argument([](const auto& arg) { return arg.nargs == 0; }))
|
||||
{
|
||||
std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl;
|
||||
std::cout << std::endl;
|
||||
for(auto&& arg : self.arguments)
|
||||
{
|
||||
if(arg.nargs != 0)
|
||||
continue;
|
||||
const int col_align = 35;
|
||||
std::string prefix = " ";
|
||||
int len = 0;
|
||||
std::cout << color::fg_green;
|
||||
for(const std::string& a : arg.flags)
|
||||
{
|
||||
len += prefix.length() + a.length();
|
||||
std::cout << prefix;
|
||||
std::cout << a;
|
||||
prefix = ", ";
|
||||
}
|
||||
std::cout << color::reset;
|
||||
int spaces = col_align - len;
|
||||
if(spaces < 0)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < spaces; i++)
|
||||
std::cout << " ";
|
||||
}
|
||||
std::cout << arg.help << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if(self.find_argument([](const auto& arg) { return arg.nargs != 0; }))
|
||||
{
|
||||
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
|
||||
for(auto&& arg : self.arguments)
|
||||
{
|
||||
if(arg.nargs == 0)
|
||||
continue;
|
||||
std::cout << std::endl;
|
||||
std::string prefix = " ";
|
||||
std::cout << color::fg_green;
|
||||
if(arg.flags.empty())
|
||||
{
|
||||
std::cout << prefix;
|
||||
std::cout << arg.metavar;
|
||||
}
|
||||
for(const std::string& a : arg.flags)
|
||||
{
|
||||
std::cout << prefix;
|
||||
std::cout << a;
|
||||
prefix = ", ";
|
||||
}
|
||||
std::cout << color::reset;
|
||||
if(not arg.type.empty())
|
||||
{
|
||||
std::cout << " [" << color::fg_blue << arg.type << color::reset << "]";
|
||||
if(not arg.default_value.empty())
|
||||
std::cout << " (Default: " << arg.default_value << ")";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << " " << arg.help << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if(not msg.empty())
|
||||
std::cout << msg << std::endl;
|
||||
});
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto help(const std::string& help)
|
||||
{
|
||||
return [=](auto&, auto& arg) { arg.help = help; };
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto metavar(const std::string& metavar)
|
||||
{
|
||||
return [=](auto&, auto& arg) { arg.metavar = metavar; };
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto type(const std::string& type)
|
||||
{
|
||||
return [=](auto&, auto& arg) { arg.type = type; };
|
||||
}
|
||||
|
||||
MIGRAPHX_DRIVER_STATIC auto group(const std::string& group)
|
||||
{
|
||||
return [=](auto&, auto& arg) { arg.group = group; };
|
||||
}
|
||||
|
||||
template <class T>
|
||||
MIGRAPHX_DRIVER_STATIC auto set_value(T value)
|
||||
{
|
||||
return [=](auto& x, auto& arg) {
|
||||
arg.nargs = 0;
|
||||
arg.type = "";
|
||||
arg.action = [&, value](auto&, const std::vector<std::string>&) {
|
||||
x = value;
|
||||
return false;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void set_exe_name_to(T& x)
|
||||
{
|
||||
actions.push_back([&](const auto& self) { x = self.exe_name; });
|
||||
}
|
||||
|
||||
void print_try_help()
|
||||
{
|
||||
if(has_argument([](const auto& a) { return contains(a.flags, "--help"); }))
|
||||
{
|
||||
std::cout << std::endl;
|
||||
std::cout << "For more information try '" << color::fg_green << "--help" << color::reset
|
||||
<< "'" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void print_usage(const std::vector<std::string>& flags) const
|
||||
{
|
||||
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
|
||||
std::cout << " " << exe_name << " ";
|
||||
std::cout << join_strings(flags, " ") << std::endl;
|
||||
}
|
||||
|
||||
auto spellcheck(const std::vector<std::string>& inputs)
|
||||
{
|
||||
struct result_t
|
||||
{
|
||||
const argument* arg = nullptr;
|
||||
std::string correct = "";
|
||||
std::string incorrect = "";
|
||||
std::ptrdiff_t distance = std::numeric_limits<std::ptrdiff_t>::max();
|
||||
};
|
||||
result_t result;
|
||||
for(const auto& input : inputs)
|
||||
{
|
||||
if(input.empty())
|
||||
continue;
|
||||
if(input[0] != '-')
|
||||
continue;
|
||||
for(const auto& arg : arguments)
|
||||
{
|
||||
for(const auto& flag : arg.flags)
|
||||
{
|
||||
if(flag.empty())
|
||||
continue;
|
||||
if(flag[0] != '-')
|
||||
continue;
|
||||
std::ptrdiff_t d = levenshtein_distance(flag, input);
|
||||
if(d < result.distance)
|
||||
result = result_t{&arg, flag, input, d};
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool
|
||||
run_action(const argument& arg, const std::string& flag, const std::vector<std::string>& inputs)
|
||||
{
|
||||
std::string msg = "";
|
||||
try
|
||||
{
|
||||
for(const auto& v : arg.validations)
|
||||
v(*this, inputs);
|
||||
return arg.action(*this, inputs);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
msg = e.what();
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
msg = "unknown exception";
|
||||
}
|
||||
std::cout << color::fg_red << color::bold << "error: " << color::reset;
|
||||
auto sc = spellcheck(inputs);
|
||||
if(sc.distance < 5)
|
||||
{
|
||||
std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset
|
||||
<< "'"
|
||||
<< " which wasn't expected, or isn't valid in this context" << std::endl;
|
||||
std::cout << " "
|
||||
<< "Did you mean " << color::fg_green << sc.correct << color::reset << "?"
|
||||
<< std::endl;
|
||||
std::cout << std::endl;
|
||||
print_usage({sc.arg->usage(sc.correct)});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& flag_name = flag.empty() ? arg.metavar : flag;
|
||||
std::cout << "Invalid input to '" << color::fg_yellow;
|
||||
std::cout << arg.usage(flag_name);
|
||||
std::cout << color::reset << "'" << std::endl;
|
||||
std::cout << " " << msg << std::endl;
|
||||
std::cout << std::endl;
|
||||
print_usage({arg.usage()});
|
||||
}
|
||||
std::cout << std::endl;
|
||||
print_try_help();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool parse(std::vector<std::string> args)
|
||||
{
|
||||
std::unordered_map<std::string, unsigned> keywords;
|
||||
for(auto&& arg : arguments)
|
||||
{
|
||||
for(auto&& flag : arg.flags)
|
||||
keywords[flag] = arg.nargs + 1;
|
||||
}
|
||||
auto arg_map =
|
||||
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
|
||||
std::list<const argument*> missing_arguments;
|
||||
std::unordered_set<std::string> groups_used;
|
||||
for(auto&& arg : arguments)
|
||||
{
|
||||
bool used = false;
|
||||
auto flags = arg.flags;
|
||||
if(flags.empty())
|
||||
flags = {""};
|
||||
for(auto&& flag : flags)
|
||||
{
|
||||
if(arg_map.count(flag) > 0)
|
||||
{
|
||||
if(run_action(arg, flag, arg_map[flag]))
|
||||
return true;
|
||||
used = true;
|
||||
}
|
||||
}
|
||||
if(used and not arg.group.empty())
|
||||
groups_used.insert(arg.group);
|
||||
if(arg.required and not used)
|
||||
missing_arguments.push_back(&arg);
|
||||
}
|
||||
// Remove arguments from a group that is being used
|
||||
missing_arguments.remove_if(
|
||||
[&](const argument* arg) { return groups_used.count(arg->group); });
|
||||
if(not missing_arguments.empty())
|
||||
{
|
||||
std::cout << color::fg_red << color::bold << "error: " << color::reset;
|
||||
std::cout << "The following required arguments were not provided:" << std::endl;
|
||||
std::cout << " " << color::fg_red
|
||||
<< join_strings(get_argument_usages(std::move(missing_arguments)), " ")
|
||||
<< color::reset << std::endl;
|
||||
std::cout << std::endl;
|
||||
auto required_usages = get_argument_usages(get_required_arguments());
|
||||
print_usage(required_usages);
|
||||
print_try_help();
|
||||
return true;
|
||||
}
|
||||
for(auto&& action : actions)
|
||||
action(*this);
|
||||
return false;
|
||||
}
|
||||
|
||||
void set_exe_name(const std::string& s) { exe_name = s; }
|
||||
|
||||
const std::string& get_exe_name() const { return exe_name; }
|
||||
|
||||
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
|
||||
template <class IsKeyword>
|
||||
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
|
||||
{
|
||||
string_map result;
|
||||
|
||||
std::string flag;
|
||||
bool clear = false;
|
||||
for(auto&& x : as)
|
||||
{
|
||||
auto k = is_keyword(x);
|
||||
if(k > 0)
|
||||
{
|
||||
flag = x;
|
||||
result[flag]; // Ensure the flag exists
|
||||
if(k == 1)
|
||||
flag = "";
|
||||
else if(k == 2)
|
||||
clear = true;
|
||||
else
|
||||
clear = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
result[flag].push_back(x);
|
||||
if(clear)
|
||||
flag = "";
|
||||
clear = false;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::list<argument> arguments;
|
||||
std::string exe_name = "";
|
||||
std::vector<std::function<void(argument_parser&)>> actions;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
116
docker/rocm/migraphx/driver/command.hpp
Normal file
116
docker/rocm/migraphx/driver/command.hpp
Normal file
@ -0,0 +1,116 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
|
||||
|
||||
#include "argument_parser.hpp"
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/type_name.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
inline auto& get_commands()
|
||||
{
|
||||
// NOLINTNEXTLINE
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
|
||||
m;
|
||||
return m;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::string compute_command_name()
|
||||
{
|
||||
static const std::string& tname = get_type_name<T>();
|
||||
auto name = tname.substr(tname.rfind("::") + 2);
|
||||
if(ends_with(name, "_command"))
|
||||
name = name.substr(0, name.size() - 8);
|
||||
if(ends_with(name, "_cmd"))
|
||||
name = name.substr(0, name.size() - 4);
|
||||
return name;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
const std::string& command_name()
|
||||
{
|
||||
static const std::string& name = compute_command_name<T>();
|
||||
return name;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
|
||||
{
|
||||
T x;
|
||||
argument_parser ap;
|
||||
ap.set_exe_name(exe_name + " " + command_name<T>());
|
||||
if(add_help)
|
||||
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
|
||||
x.parse(ap);
|
||||
if(ap.parse(std::move(args)))
|
||||
return;
|
||||
x.run();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
int auto_register_command()
|
||||
{
|
||||
auto& m = get_commands();
|
||||
m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
|
||||
run_command<T>(exe_name, args, true);
|
||||
};
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct command
|
||||
{
|
||||
static const int static_register;
|
||||
// This typedef ensures that the static member will be instantiated if
|
||||
// the class itself is instantiated
|
||||
using static_register_type =
|
||||
std::integral_constant<decltype(&static_register), &static_register>;
|
||||
};
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wglobal-constructors"
|
||||
#endif
|
||||
|
||||
template <class T>
|
||||
const int command<T>::static_register = auto_register_command<T>(); // NOLINT
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
969
docker/rocm/migraphx/driver/main.cpp
Normal file
969
docker/rocm/migraphx/driver/main.cpp
Normal file
@ -0,0 +1,969 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "verify.hpp"
|
||||
#include "verify_options.hpp"
|
||||
#include "argument_parser.hpp"
|
||||
#include "command.hpp"
|
||||
#include "mlir.hpp"
|
||||
#include "precision.hpp"
|
||||
#include "passes.hpp"
|
||||
#include "perf.hpp"
|
||||
#include "models.hpp"
|
||||
#include "marker_roctx.hpp"
|
||||
|
||||
#include <migraphx/tf.hpp>
|
||||
#include <migraphx/onnx.hpp>
|
||||
#ifdef MIGRAPHX_ENABLE_PYTHON
|
||||
#include <migraphx/py.hpp>
|
||||
#endif
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/convert_to_json.hpp>
|
||||
#include <migraphx/load_save.hpp>
|
||||
#include <migraphx/json.hpp>
|
||||
#include <migraphx/version.h>
|
||||
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/eliminate_identity.hpp>
|
||||
#include <migraphx/eliminate_pad.hpp>
|
||||
#include <migraphx/generate.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/propagate_constant.hpp>
|
||||
#include <migraphx/quantization.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
#include <migraphx/simplify_algebra.hpp>
|
||||
#include <migraphx/simplify_reshapes.hpp>
|
||||
#include <migraphx/register_target.hpp>
|
||||
|
||||
#include <migraphx/netron_output.hpp>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
inline std::string get_version()
|
||||
{
|
||||
return "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + "." +
|
||||
std::to_string(MIGRAPHX_VERSION_MINOR) + "." + std::to_string(MIGRAPHX_VERSION_PATCH) +
|
||||
"." MIGRAPHX_VERSION_TWEAK;
|
||||
}
|
||||
|
||||
struct loader
|
||||
{
|
||||
std::string file;
|
||||
std::string file_type;
|
||||
unsigned batch = 1;
|
||||
bool is_nhwc = true;
|
||||
bool is_test = false;
|
||||
unsigned trim = 0;
|
||||
bool optimize = false;
|
||||
bool mlir = false;
|
||||
bool skip_unknown_operators = false;
|
||||
bool brief = false;
|
||||
std::string output_type;
|
||||
std::string output;
|
||||
std::string default_dyn_dim;
|
||||
std::vector<std::string> param_dims;
|
||||
std::vector<std::string> dim_params;
|
||||
std::vector<std::string> dyn_param_dims;
|
||||
std::vector<std::string> output_names;
|
||||
std::vector<std::string> passes;
|
||||
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
|
||||
ap(is_test,
|
||||
{"--test"},
|
||||
ap.help("Run a single GEMM to test MIGraphX"),
|
||||
ap.set_value(true),
|
||||
ap.group("input"));
|
||||
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
|
||||
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
|
||||
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
|
||||
ap(file_type, {"--migraphx-json"}, ap.help("Load as MIGraphX JSON"), ap.set_value("json"));
|
||||
ap(batch,
|
||||
{"--batch"},
|
||||
ap.help("For a static model, sets default_dim_value size (commonly batch size). For a "
|
||||
"dynamic batch model, sets the batch "
|
||||
"size at runtime."));
|
||||
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
|
||||
ap(skip_unknown_operators,
|
||||
{"--skip-unknown-operators"},
|
||||
ap.help("Skip unknown operators when parsing and continue to parse."),
|
||||
ap.set_value(true));
|
||||
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
|
||||
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
|
||||
ap(param_dims,
|
||||
{"--input-dim"},
|
||||
ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
|
||||
ap.append(),
|
||||
ap.nargs(2));
|
||||
ap(dim_params,
|
||||
{"--dim-param"},
|
||||
ap.help("Symbolic parameter dimension name (fixed / dynamic) - "
|
||||
"(fixed format): \"@dim_param_name\" \"x\" / "
|
||||
"(dynamic format): \"@dim_param_name\" \"{min:x, max:y, optimals:[o1,o2]}\""),
|
||||
ap.append(),
|
||||
ap.nargs(2));
|
||||
ap(dyn_param_dims,
|
||||
{"--dyn-input-dim"},
|
||||
ap.help("Dynamic dimensions of a parameter (format: \"@name_1\" \"[{min:x, max:y, "
|
||||
"optimals:[o1,o2,...]}, dim2,dim3, ...]\", \"@name_2\", ... You can supply a "
|
||||
"single integer value for a dimension to specify it as fixed."),
|
||||
ap.append(),
|
||||
ap.nargs(2));
|
||||
ap(default_dyn_dim,
|
||||
{"--default-dyn-dim"},
|
||||
ap.help("Default dynamic dimension (format: \"{min:x, max:y, optimals:[o1,o2]}\")."));
|
||||
ap(output_names,
|
||||
{"--output-names"},
|
||||
ap.help("Names of node output (format: \"name_1 name_2 name_n\")"),
|
||||
ap.append(),
|
||||
ap.nargs(2));
|
||||
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
|
||||
ap(mlir, {"--mlir"}, ap.help("Offload everything to mlir"), ap.set_value(true));
|
||||
ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
|
||||
ap(output_type,
|
||||
{"--graphviz", "-g"},
|
||||
ap.help("Print out a graphviz representation."),
|
||||
ap.set_value("graphviz"));
|
||||
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
|
||||
ap(output_type,
|
||||
{"--cpp"},
|
||||
ap.help("Print out the program as C++ program."),
|
||||
ap.set_value("cpp"));
|
||||
ap(output_type,
|
||||
{"--python", "--py"},
|
||||
ap.help("Print out the program as python program."),
|
||||
ap.set_value("py"));
|
||||
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
|
||||
ap(output_type,
|
||||
{"--text"},
|
||||
ap.help("Print out program in text format."),
|
||||
ap.set_value("text"));
|
||||
ap(output_type,
|
||||
{"--binary"},
|
||||
ap.help("Print out program in binary format."),
|
||||
ap.set_value("binary"));
|
||||
ap(output_type,
|
||||
{"--netron"},
|
||||
ap.help("Print out program as Netron readable json."),
|
||||
ap.set_value("netron"));
|
||||
ap(output, {"--output", "-o"}, ap.help("Output to file."));
|
||||
}
|
||||
|
||||
static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
|
||||
{
|
||||
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
|
||||
std::string name = "";
|
||||
for(auto&& x : param_dims_info)
|
||||
{
|
||||
if(x[0] == '@')
|
||||
{
|
||||
name = x.substr(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
map_input_dims[name].push_back(value_parser<std::size_t>::apply(x));
|
||||
}
|
||||
}
|
||||
|
||||
return map_input_dims;
|
||||
}
|
||||
|
||||
static auto parse_dyn_dims_json(const std::string& dd_json)
|
||||
{
|
||||
// expecting a json string like "[{min:1,max:64,optimals:[1,2,4,8]},3,224,224]"
|
||||
auto v = from_json_string(convert_to_json(dd_json));
|
||||
std::vector<migraphx::shape::dynamic_dimension> dyn_dims;
|
||||
std::transform(v.begin(), v.end(), std::back_inserter(dyn_dims), [&](auto x) {
|
||||
if(x.is_object())
|
||||
return from_value<migraphx::shape::dynamic_dimension>(x);
|
||||
auto d = x.template to<std::size_t>();
|
||||
return migraphx::shape::dynamic_dimension{d, d};
|
||||
});
|
||||
return dyn_dims;
|
||||
}
|
||||
|
||||
static auto parse_dyn_dims_map(const std::vector<std::string>& param_dyn_dims)
|
||||
{
|
||||
// expecting vector of strings formatted like
|
||||
// {"@param_name_0", "dd_json_0", "@param_name_1", "dd_json_1", ...}
|
||||
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
|
||||
std::string name = "";
|
||||
for(auto&& x : param_dyn_dims)
|
||||
{
|
||||
if(x[0] == '@')
|
||||
{
|
||||
name = x.substr(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
map_dyn_input_dims[name] = parse_dyn_dims_json(x);
|
||||
}
|
||||
}
|
||||
return map_dyn_input_dims;
|
||||
}
|
||||
|
||||
static auto parse_dim_params(const std::vector<std::string>& dim_params_info)
|
||||
{
|
||||
std::unordered_map<std::string, shape::dynamic_dimension> map_dim_params;
|
||||
std::string name = "";
|
||||
for(auto&& x : dim_params_info)
|
||||
{
|
||||
if(x[0] == '@')
|
||||
{
|
||||
name = x.substr(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(std::all_of(x.begin(), x.end(), [](char ch) {
|
||||
return std::isdigit(static_cast<unsigned char>(ch));
|
||||
}))
|
||||
map_dim_params[name] = {std::stoul(x), std::stoul(x)};
|
||||
else
|
||||
{
|
||||
auto dyn_dim = parse_dyn_dims_json(x);
|
||||
if(dyn_dim.size() != 1)
|
||||
MIGRAPHX_THROW("dim_param must only specifiy one dimension");
|
||||
map_dim_params[name] = dyn_dim.front();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map_dim_params;
|
||||
}
|
||||
|
||||
static auto parse_output_names(const std::vector<std::string>& output_names_info)
|
||||
{
|
||||
std::vector<std::string> output_node_names;
|
||||
std::transform(output_names_info.begin(),
|
||||
output_names_info.end(),
|
||||
std::back_inserter(output_node_names),
|
||||
[&](auto x) { return value_parser<std::string>::apply(x); });
|
||||
|
||||
return output_node_names;
|
||||
}
|
||||
|
||||
tf_options get_tf_options() const
|
||||
{
|
||||
auto map_input_dims = parse_param_dims(param_dims);
|
||||
auto output_node_names = parse_output_names(output_names);
|
||||
tf_options options;
|
||||
options.is_nhwc = is_nhwc;
|
||||
options.batch_size = batch;
|
||||
options.map_input_dims = map_input_dims;
|
||||
options.output_node_names = output_node_names;
|
||||
return options;
|
||||
}
|
||||
|
||||
onnx_options get_onnx_options() const
|
||||
{
|
||||
auto map_input_dims = parse_param_dims(param_dims);
|
||||
auto map_dyn_input_dims = parse_dyn_dims_map(dyn_param_dims);
|
||||
auto map_dim_params = parse_dim_params(dim_params);
|
||||
onnx_options options;
|
||||
if(default_dyn_dim.empty())
|
||||
{
|
||||
options.default_dim_value = batch;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v = from_json_string(convert_to_json(default_dyn_dim));
|
||||
options.default_dyn_dim_value = from_value<migraphx::shape::dynamic_dimension>(v);
|
||||
}
|
||||
options.skip_unknown_operators = skip_unknown_operators;
|
||||
options.print_program_on_error = true;
|
||||
options.map_input_dims = map_input_dims;
|
||||
options.map_dyn_input_dims = map_dyn_input_dims;
|
||||
options.dim_params = map_dim_params;
|
||||
return options;
|
||||
}
|
||||
|
||||
static std::string get_file_type(const std::string& file)
|
||||
{
|
||||
if(ends_with(file, ".onnx"))
|
||||
return "onnx";
|
||||
else if(ends_with(file, ".pb"))
|
||||
return "tf";
|
||||
else if(ends_with(file, ".json"))
|
||||
return "json";
|
||||
else if(ends_with(file, ".py"))
|
||||
return "py";
|
||||
else
|
||||
return "migraphx";
|
||||
}
|
||||
|
||||
program load()
|
||||
{
|
||||
program p;
|
||||
if(is_test)
|
||||
{
|
||||
p = test_gemm();
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file_type.empty())
|
||||
{
|
||||
file_type = get_file_type(file);
|
||||
}
|
||||
std::cout << "Reading: " << file << std::endl;
|
||||
if(file_type == "onnx")
|
||||
{
|
||||
p = parse_onnx(file, get_onnx_options());
|
||||
}
|
||||
else if(file_type == "tf")
|
||||
{
|
||||
p = parse_tf(file, get_tf_options());
|
||||
}
|
||||
else if(file_type == "json")
|
||||
{
|
||||
file_options options;
|
||||
options.format = "json";
|
||||
p = migraphx::load(file, options);
|
||||
}
|
||||
#ifdef MIGRAPHX_ENABLE_PYTHON
|
||||
else if(file_type == "py")
|
||||
{
|
||||
p = migraphx::load_py(file);
|
||||
}
|
||||
#endif
|
||||
else if(file_type == "migraphx")
|
||||
{
|
||||
p = migraphx::load(file);
|
||||
}
|
||||
}
|
||||
if(trim > 0)
|
||||
{
|
||||
auto* mm = p.get_main_module();
|
||||
auto last = std::prev(mm->end(), trim);
|
||||
mm->remove_instructions(last, mm->end());
|
||||
}
|
||||
// Remove unused variable when exporting to cpp
|
||||
if(output_type == "cpp")
|
||||
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
|
||||
if(optimize)
|
||||
{
|
||||
migraphx::run_passes(*p.get_main_module(),
|
||||
{
|
||||
migraphx::eliminate_identity{},
|
||||
migraphx::dead_code_elimination{},
|
||||
migraphx::simplify_algebra{},
|
||||
migraphx::dead_code_elimination{},
|
||||
migraphx::simplify_reshapes{},
|
||||
migraphx::dead_code_elimination{},
|
||||
migraphx::propagate_constant{},
|
||||
migraphx::dead_code_elimination{},
|
||||
migraphx::eliminate_pad{},
|
||||
migraphx::dead_code_elimination{},
|
||||
});
|
||||
}
|
||||
if(not passes.empty())
|
||||
migraphx::run_passes(p, get_passes(passes));
|
||||
if(mlir)
|
||||
offload_to_mlir(p);
|
||||
return p;
|
||||
}
|
||||
|
||||
static void write(std::ostream& os, const std::vector<char>& buffer)
|
||||
{
|
||||
os.write(buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
void save(const program& p) const
|
||||
{
|
||||
auto* os = &std::cout;
|
||||
std::ofstream fs;
|
||||
if(not output.empty())
|
||||
{
|
||||
fs.open(output, std::ios::binary);
|
||||
os = &fs;
|
||||
}
|
||||
|
||||
std::string type = output_type;
|
||||
if(type.empty())
|
||||
{
|
||||
if(output.empty())
|
||||
type = "text";
|
||||
else
|
||||
type = "binary";
|
||||
}
|
||||
|
||||
if(type == "py")
|
||||
p.print_py(*os);
|
||||
else if(type == "cpp")
|
||||
p.print_cpp(*os);
|
||||
else if(type == "graphviz")
|
||||
p.print_graph(*os, brief);
|
||||
else if(type == "text")
|
||||
*os << p << std::endl;
|
||||
else if(type == "json")
|
||||
*os << to_json_string(p.to_value()) << std::endl;
|
||||
else if(type == "binary")
|
||||
write(*os, save_buffer(p));
|
||||
else if(type == "netron")
|
||||
*os << make_netron_output(p) << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
struct program_params
|
||||
{
|
||||
std::vector<std::string> fill0{};
|
||||
std::vector<std::string> fill1{};
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append(), ap.nargs(2));
|
||||
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2));
|
||||
}
|
||||
|
||||
auto generate(const program& p, const target& t, bool offload, unsigned batch)
|
||||
{
|
||||
parameter_map m;
|
||||
auto param_shapes = p.get_parameter_shapes();
|
||||
std::unordered_map<std::string, shape> static_param_shapes;
|
||||
std::transform(
|
||||
param_shapes.cbegin(),
|
||||
param_shapes.cend(),
|
||||
std::inserter(static_param_shapes, static_param_shapes.end()),
|
||||
[&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); });
|
||||
for(auto&& s : fill0)
|
||||
m[s] = fill_argument(static_param_shapes.at(s), 0);
|
||||
for(auto&& s : fill1)
|
||||
m[s] = fill_argument(static_param_shapes.at(s), 1);
|
||||
fill_param_map(m, static_param_shapes, t, offload);
|
||||
return m;
|
||||
}
|
||||
};
|
||||
|
||||
struct compiler_target
|
||||
{
|
||||
#ifdef HAVE_GPU
|
||||
std::string target_name = "gpu";
|
||||
#elif defined(HAVE_CPU)
|
||||
std::string target_name = "cpu";
|
||||
#elif defined(HAVE_FPGA)
|
||||
std::string target_name = "fpga";
|
||||
#else
|
||||
std::string target_name = "ref";
|
||||
#endif
|
||||
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(target_name, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value("gpu"));
|
||||
ap(target_name, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value("cpu"));
|
||||
ap(target_name,
|
||||
{"--ref"},
|
||||
ap.help("Compile on the reference implementation"),
|
||||
ap.set_value("ref"));
|
||||
}
|
||||
|
||||
target get_target() const { return make_target(target_name); }
|
||||
};
|
||||
|
||||
struct compiler
|
||||
{
|
||||
loader l;
|
||||
program_params parameters;
|
||||
compiler_target ct;
|
||||
compile_options co;
|
||||
bool to_fp16 = false;
|
||||
bool to_bf16 = false;
|
||||
bool to_fp8 = false;
|
||||
bool to_int8 = false;
|
||||
bool to_int4 = false;
|
||||
|
||||
std::vector<std::string> fill0;
|
||||
std::vector<std::string> fill1;
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
l.parse(ap);
|
||||
parameters.parse(ap);
|
||||
ct.parse(ap);
|
||||
ap(co.offload_copy,
|
||||
{"--enable-offload-copy"},
|
||||
ap.help("Enable implicit offload copying"),
|
||||
ap.set_value(true));
|
||||
ap(co.fast_math,
|
||||
{"--disable-fast-math"},
|
||||
ap.help("Disable fast math optimization"),
|
||||
ap.set_value(false));
|
||||
ap(co.exhaustive_tune,
|
||||
{"--exhaustive-tune"},
|
||||
ap.help("Exhastively search for best tuning parameters for kernels"),
|
||||
ap.set_value(true));
|
||||
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
|
||||
ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
|
||||
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
|
||||
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
|
||||
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
|
||||
}
|
||||
|
||||
auto params(const program& p)
|
||||
{
|
||||
return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
|
||||
}
|
||||
|
||||
auto host_params(const program& p)
|
||||
{
|
||||
return parameters.generate(p, ct.get_target(), true, l.batch);
|
||||
}
|
||||
|
||||
program compile()
|
||||
{
|
||||
auto p = l.load();
|
||||
// Dont compile if its already been compiled
|
||||
|
||||
if(p.is_compiled())
|
||||
{
|
||||
if(ct.target_name == "gpu")
|
||||
{
|
||||
if(is_offload_copy_set(p) and not co.offload_copy)
|
||||
{
|
||||
std::cout
|
||||
<< "[WARNING]: MIGraphX program was likely compiled with offload_copy "
|
||||
"set, Try "
|
||||
"passing "
|
||||
"`--enable-offload-copy` if program run fails.\n";
|
||||
}
|
||||
else if(not is_offload_copy_set(p) and co.offload_copy)
|
||||
{
|
||||
std::cout << "[WARNING]: MIGraphX program was likely compiled without "
|
||||
"offload_copy set, Try "
|
||||
"removing "
|
||||
"`--enable-offload-copy` if program run "
|
||||
"fails.\n";
|
||||
}
|
||||
}
|
||||
|
||||
return p;
|
||||
}
|
||||
auto t = ct.get_target();
|
||||
if(to_fp16)
|
||||
{
|
||||
quantize_fp16(p);
|
||||
}
|
||||
if(to_bf16)
|
||||
{
|
||||
quantize_bf16(p);
|
||||
}
|
||||
if(to_int8)
|
||||
{
|
||||
quantize_int8(p, t, {host_params(p)});
|
||||
}
|
||||
if(to_fp8)
|
||||
{
|
||||
quantize_fp8(p, t, {host_params(p)});
|
||||
}
|
||||
if(to_int4)
|
||||
{
|
||||
quantize_int4_weights(p);
|
||||
}
|
||||
p.compile(t, co);
|
||||
l.save(p);
|
||||
return p;
|
||||
}
|
||||
};
|
||||
|
||||
struct read : command<read>
|
||||
{
|
||||
loader l;
|
||||
void parse(argument_parser& ap) { l.parse(ap); }
|
||||
|
||||
void run()
|
||||
{
|
||||
auto p = l.load();
|
||||
l.save(p);
|
||||
}
|
||||
};
|
||||
|
||||
struct params : command<params>
|
||||
{
|
||||
loader l;
|
||||
void parse(argument_parser& ap) { l.parse(ap); }
|
||||
|
||||
void run()
|
||||
{
|
||||
auto p = l.load();
|
||||
for(auto&& param : p.get_parameter_shapes())
|
||||
std::cout << param.first << ": " << param.second << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
struct verify : command<verify>
|
||||
{
|
||||
compiler c;
|
||||
std::optional<double> rms_tol;
|
||||
std::optional<double> atol;
|
||||
std::optional<double> rtol;
|
||||
bool per_instruction = false;
|
||||
bool reduce = false;
|
||||
bool bisect = false;
|
||||
verify_options vo;
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
c.parse(ap);
|
||||
ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
|
||||
ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
|
||||
ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
|
||||
ap(per_instruction,
|
||||
{"-i", "--per-instruction"},
|
||||
ap.help("Verify each instruction"),
|
||||
ap.set_value(true));
|
||||
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
|
||||
ap(bisect, {"-b", "--bisect"}, ap.help("Bisect program and verify"), ap.set_value(true));
|
||||
ap(vo.ref_use_double,
|
||||
{"--ref-use-double"},
|
||||
ap.help("Convert floating point values to double on ref"),
|
||||
ap.set_value(true));
|
||||
}
|
||||
|
||||
void run()
|
||||
{
|
||||
auto p = c.l.load();
|
||||
c.l.save(p);
|
||||
std::cout << p << std::endl;
|
||||
|
||||
auto t = c.ct.get_target();
|
||||
auto m = c.parameters.generate(p, t, true, c.l.batch);
|
||||
|
||||
if(c.to_fp16)
|
||||
{
|
||||
vo.quantize = precision::fp16;
|
||||
}
|
||||
if(c.to_bf16)
|
||||
{
|
||||
vo.quantize = precision::bf16;
|
||||
}
|
||||
if(c.to_int8)
|
||||
{
|
||||
vo.quantize = precision::int8;
|
||||
}
|
||||
|
||||
auto tols = get_tolerances(p, vo, rms_tol, atol, rtol);
|
||||
std::cout << "rms_tol: " << tols.rms_tol << std::endl;
|
||||
std::cout << "atol: " << tols.atol << std::endl;
|
||||
std::cout << "rtol: " << tols.rtol << std::endl;
|
||||
|
||||
if(per_instruction)
|
||||
{
|
||||
verify_instructions(p, t, c.co, vo, tols);
|
||||
}
|
||||
else if(reduce)
|
||||
{
|
||||
verify_reduced_program(p, t, c.co, vo, m, tols);
|
||||
}
|
||||
else if(bisect)
|
||||
{
|
||||
verify_bisected_program(p, t, c.co, vo, m, tols);
|
||||
}
|
||||
else
|
||||
{
|
||||
verify_program(c.l.file, p, t, c.co, vo, m, tols);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct compile : command<compile>
|
||||
{
|
||||
compiler c;
|
||||
void parse(argument_parser& ap) { c.parse(ap); }
|
||||
|
||||
void run()
|
||||
{
|
||||
std::cout << "Compiling ... " << std::endl;
|
||||
c.compile();
|
||||
}
|
||||
};
|
||||
|
||||
struct run_cmd : command<run_cmd>
|
||||
{
|
||||
compiler c;
|
||||
void parse(argument_parser& ap) { c.parse(ap); }
|
||||
|
||||
void run()
|
||||
{
|
||||
std::cout << "Compiling ... " << std::endl;
|
||||
auto p = c.compile();
|
||||
std::cout << "Allocating params ... " << std::endl;
|
||||
auto m = c.params(p);
|
||||
p.eval(m);
|
||||
std::cout << p << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
struct time_cmd : command<time_cmd>
|
||||
{
|
||||
compiler c;
|
||||
unsigned n = 100;
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(n, {"--iterations", "-n"}, ap.help("Number of iterations to run."));
|
||||
c.parse(ap);
|
||||
}
|
||||
|
||||
void run()
|
||||
{
|
||||
std::cout << "Compiling ... " << std::endl;
|
||||
auto p = c.compile();
|
||||
std::cout << "Allocating params ... " << std::endl;
|
||||
auto m = c.params(p);
|
||||
std::cout << "Running ... " << std::endl;
|
||||
double t = time_run(p, m, n);
|
||||
std::cout << "Total time: " << t << "ms" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
struct perf : command<perf>
|
||||
{
|
||||
compiler c;
|
||||
unsigned n = 100;
|
||||
bool detailed = false;
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
c.parse(ap);
|
||||
ap(n, {"--iterations", "-n"}, ap.help("Number of iterations to run for perf report"));
|
||||
ap(detailed,
|
||||
{"--detailed", "-d"},
|
||||
ap.help("Show a more detailed summary report"),
|
||||
ap.set_value(true));
|
||||
}
|
||||
|
||||
void run()
|
||||
{
|
||||
std::cout << "Compiling ... " << std::endl;
|
||||
auto p = c.compile();
|
||||
std::cout << "Allocating params ... " << std::endl;
|
||||
auto m = c.params(p);
|
||||
std::cout << "Running performance report ... " << std::endl;
|
||||
p.perf_report(std::cout, n, m, c.l.batch, detailed);
|
||||
}
|
||||
};
|
||||
|
||||
struct roctx : command<roctx>
|
||||
{
|
||||
compiler c;
|
||||
void parse(argument_parser& ap) { c.parse(ap); }
|
||||
|
||||
void run()
|
||||
{
|
||||
std::cout << "Compiling ... " << std::endl;
|
||||
auto p = c.compile();
|
||||
std::cout << "Allocating params ... " << std::endl;
|
||||
auto m = c.params(p);
|
||||
std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
|
||||
auto rtx = create_marker_roctx();
|
||||
p.mark(m, std::move(rtx));
|
||||
}
|
||||
};
|
||||
|
||||
struct op : command<op>
|
||||
{
|
||||
bool show_ops = false;
|
||||
std::string op_name{};
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
|
||||
ap(show_ops,
|
||||
{"--list", "-l"},
|
||||
ap.help("List all the operators of MIGraphX"),
|
||||
ap.set_value(true));
|
||||
}
|
||||
void run() const
|
||||
{
|
||||
if(show_ops)
|
||||
{
|
||||
for(const auto& name : get_operators())
|
||||
std::cout << name << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto op = load_op(op_name);
|
||||
std::cout << op_name << ": " << std::endl;
|
||||
std::cout << to_pretty_json_string(op.to_value()) << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct onnx : command<onnx>
|
||||
{
|
||||
bool show_ops = false;
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(show_ops,
|
||||
{"--list", "-l"},
|
||||
ap.help("List all onnx operators supported by MIGraphX"),
|
||||
ap.set_value(true));
|
||||
}
|
||||
void run() const
|
||||
{
|
||||
if(show_ops)
|
||||
{
|
||||
for(const auto& name : get_onnx_operators())
|
||||
std::cout << name << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct tf : command<tf>
|
||||
{
|
||||
bool show_ops = false;
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
ap(show_ops,
|
||||
{"--list", "-l"},
|
||||
ap.help("List all tf operators supported by MIGraphX"),
|
||||
ap.set_value(true));
|
||||
}
|
||||
void run() const
|
||||
{
|
||||
if(show_ops)
|
||||
{
|
||||
for(const auto& name : get_tf_operators())
|
||||
std::cout << name << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct main_command
|
||||
{
|
||||
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
|
||||
"COMMANDS:"))
|
||||
{
|
||||
std::string result = title + "\n";
|
||||
std::vector<std::string> commands(get_commands().size());
|
||||
std::transform(get_commands().begin(),
|
||||
get_commands().end(),
|
||||
commands.begin(),
|
||||
[](const auto& p) { return colorize(color::fg_green, p.first); });
|
||||
std::sort(commands.begin(), commands.end());
|
||||
return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
|
||||
return r + " " + s + "\n";
|
||||
});
|
||||
}
|
||||
void parse(argument_parser& ap)
|
||||
{
|
||||
std::string version_str = get_version();
|
||||
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
|
||||
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
|
||||
ap(nullptr,
|
||||
{"-v", "--version"},
|
||||
ap.help("Show MIGraphX version"),
|
||||
ap.show_help(version_str));
|
||||
ap(nullptr, {"--ort-sha"}, ap.help("Show MIGraphX onnx runtime SHA"));
|
||||
|
||||
// Trim command off of exe name
|
||||
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
|
||||
ap.set_exe_name_to(exe_name);
|
||||
}
|
||||
|
||||
std::vector<std::string> wrong_commands{};
|
||||
std::string exe_name = "<exe>";
|
||||
|
||||
void run()
|
||||
{
|
||||
std::cout << color::fg_red << color::bold << "error: " << color::reset;
|
||||
auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
|
||||
return get_commands().count(c) > 0;
|
||||
});
|
||||
if(it == wrong_commands.end())
|
||||
{
|
||||
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
|
||||
<< "' is not a valid command." << std::endl;
|
||||
std::cout << get_command_help("Available commands:");
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "command '" << color::fg_yellow << *it << color::reset
|
||||
<< "' must be first argument" << std::endl;
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
|
||||
std::cout << " " << exe_name << " " << *it << " <options>" << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
using namespace migraphx::driver; // NOLINT
|
||||
int main(int argc, const char* argv[], const char* envp[])
|
||||
{
|
||||
std::vector<std::string> args(argv + 1, argv + argc);
|
||||
// no argument, print the help infomration by default
|
||||
if(args.empty())
|
||||
{
|
||||
args.push_back("-h");
|
||||
}
|
||||
|
||||
auto&& m = get_commands();
|
||||
auto cmd = args.front();
|
||||
|
||||
if(cmd == "--ort-sha")
|
||||
{
|
||||
std::cout << MIGRAPHX_ORT_SHA1 << std::endl;
|
||||
return 0;
|
||||
}
|
||||
if(cmd == "-v" or cmd == "--version")
|
||||
{
|
||||
std::cout << get_version() << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(m.count(cmd) > 0)
|
||||
{
|
||||
std::string driver_invocation =
|
||||
std::string(argv[0]) + " " + migraphx::to_string_range(args, " ");
|
||||
std::cout << "Running [ " << get_version() << " ]: " << driver_invocation << std::endl;
|
||||
|
||||
for(const char** env = envp; *env != nullptr; ++env)
|
||||
{
|
||||
std::string env_var(*env);
|
||||
size_t pos = env_var.find('=');
|
||||
if(pos != std::string::npos)
|
||||
{
|
||||
std::string key = env_var.substr(0, pos);
|
||||
if(key.find("MIGRAPHX") != std::string::npos)
|
||||
{
|
||||
std::cout << env_var << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.at(cmd)(argv[0],
|
||||
{args.begin() + 1, args.end()}); // run driver command found in commands map
|
||||
|
||||
std::cout << "[ " << get_version() << " ] Complete: " << driver_invocation << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
run_command<main_command>(argv[0], args);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
72
docker/rocm/migraphx/driver/marker_roctx.cpp
Normal file
72
docker/rocm/migraphx/driver/marker_roctx.cpp
Normal file
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include "marker_roctx.hpp"
|
||||
|
||||
#include <migraphx/dynamic_loader.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
class marker_roctx
|
||||
{
|
||||
std::function<void(const char*)> sym_roctx_mark;
|
||||
std::function<uint64_t(const char*)> sym_roctx_range_start;
|
||||
std::function<void(uint64_t)> sym_roctx_range_stop;
|
||||
|
||||
std::function<int(const char*)> sym_roctx_range_push;
|
||||
std::function<int()> sym_roctx_range_pop;
|
||||
|
||||
uint64_t range_id = 0;
|
||||
|
||||
public:
|
||||
marker_roctx()
|
||||
{
|
||||
dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"};
|
||||
sym_roctx_mark = lib.get_function<void(const char*)>("roctxMarkA");
|
||||
sym_roctx_range_start = lib.get_function<uint64_t(const char*)>("roctxRangeStartA");
|
||||
sym_roctx_range_stop = lib.get_function<void(uint64_t)>("roctxRangeStop");
|
||||
|
||||
sym_roctx_range_push = lib.get_function<int(const char*)>("roctxRangePushA");
|
||||
sym_roctx_range_pop = lib.get_function<int()>("roctxRangePop");
|
||||
|
||||
sym_roctx_mark("rocTX marker created.");
|
||||
}
|
||||
|
||||
void mark_start(instruction_ref ins_ref)
|
||||
{
|
||||
std::string text = "Marker start: " + ins_ref->name();
|
||||
sym_roctx_range_push(text.c_str());
|
||||
}
|
||||
void mark_stop(instruction_ref) { sym_roctx_range_pop(); }
|
||||
void mark_start(const program&) { range_id = sym_roctx_range_start("0"); }
|
||||
void mark_stop(const program&) { sym_roctx_range_stop(range_id); }
|
||||
};
|
||||
|
||||
marker create_marker_roctx() { return marker_roctx(); }
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
39
docker/rocm/migraphx/driver/marker_roctx.hpp
Normal file
39
docker/rocm/migraphx/driver/marker_roctx.hpp
Normal file
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
|
||||
|
||||
#include <migraphx/marker.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
marker create_marker_roctx();
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
59
docker/rocm/migraphx/driver/mlir.cpp
Normal file
59
docker/rocm/migraphx/driver/mlir.cpp
Normal file
@ -0,0 +1,59 @@
|
||||
#include "mlir.hpp"
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/param_utils.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void offload_to_mlir(program& p)
|
||||
{
|
||||
auto* mm = p.get_main_module();
|
||||
auto* mlirm = p.create_module("mlir");
|
||||
mlirm->set_bypass();
|
||||
std::vector<instruction_ref> inputs;
|
||||
copy_if(iterator_for(*mm), std::back_inserter(inputs), [&](instruction_ref ins) {
|
||||
if(ins->name() == "@param")
|
||||
return true;
|
||||
if(ins->name() == "@literal")
|
||||
return ins->get_shape().elements() != 1;
|
||||
return false;
|
||||
});
|
||||
|
||||
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
||||
std::size_t n = 0;
|
||||
for(auto ins : inputs)
|
||||
{
|
||||
map_ins[ins] = mlirm->add_parameter(param_name(n++), ins->get_shape().as_standard());
|
||||
}
|
||||
|
||||
auto mlir_last = mlirm->add_instructions(mm, &map_ins);
|
||||
mlirm->add_return(mlir_last);
|
||||
|
||||
auto last = std::prev(mm->end());
|
||||
auto mlir_op = mm->insert_instruction(last, make_op("gpu::mlir_op"), inputs, {mlirm});
|
||||
if(mlir_last.size() > 1)
|
||||
{
|
||||
std::vector<instruction_ref> outputs;
|
||||
transform(range(mlir_last.size()), std::back_inserter(outputs), [&](auto i) {
|
||||
return mm->insert_instruction(last, make_op("get_tuple_elem", {{"index", i}}), mlir_op);
|
||||
});
|
||||
mm->replace_return(outputs);
|
||||
}
|
||||
else
|
||||
{
|
||||
mm->replace_return({mlir_op});
|
||||
}
|
||||
run_passes(*mm, {dead_code_elimination{}});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
16
docker/rocm/migraphx/driver/mlir.hpp
Normal file
16
docker/rocm/migraphx/driver/mlir.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
#ifndef MIGRAPHX_GUARD_DRIVER_MLIR_HPP
|
||||
#define MIGRAPHX_GUARD_DRIVER_MLIR_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void offload_to_mlir(program& p);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_DRIVER_MLIR_HPP
|
||||
45
docker/rocm/migraphx/driver/models.cpp
Normal file
45
docker/rocm/migraphx/driver/models.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "models.hpp"
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
migraphx::program test_gemm()
|
||||
{
|
||||
migraphx::program p;
|
||||
auto* mm = p.get_main_module();
|
||||
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
|
||||
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
|
||||
mm->add_instruction(migraphx::make_op("dot"), a, b);
|
||||
return p;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
35
docker/rocm/migraphx/driver/models.hpp
Normal file
35
docker/rocm/migraphx/driver/models.hpp
Normal file
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/program.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
migraphx::program test_gemm();
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
113
docker/rocm/migraphx/driver/passes.cpp
Normal file
113
docker/rocm/migraphx/driver/passes.cpp
Normal file
@ -0,0 +1,113 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "passes.hpp"
|
||||
|
||||
#include <migraphx/auto_contiguous.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/eliminate_allocation.hpp>
|
||||
#include <migraphx/eliminate_common_subexpression.hpp>
|
||||
#include <migraphx/eliminate_concat.hpp>
|
||||
#include <migraphx/eliminate_contiguous.hpp>
|
||||
#include <migraphx/eliminate_data_type.hpp>
|
||||
#include <migraphx/eliminate_identity.hpp>
|
||||
#include <migraphx/eliminate_pad.hpp>
|
||||
#include <migraphx/fuse_pointwise.hpp>
|
||||
#include <migraphx/fuse_reduce.hpp>
|
||||
#include <migraphx/inline_module.hpp>
|
||||
#include <migraphx/insert_pad.hpp>
|
||||
#include <migraphx/normalize_ops.hpp>
|
||||
#include <migraphx/optimize_module.hpp>
|
||||
#include <migraphx/promote_literals.hpp>
|
||||
#include <migraphx/propagate_constant.hpp>
|
||||
#include <migraphx/rewrite_gelu.hpp>
|
||||
#include <migraphx/rewrite_pooling.hpp>
|
||||
#include <migraphx/rewrite_quantization.hpp>
|
||||
#include <migraphx/rewrite_rnn.hpp>
|
||||
#include <migraphx/simplify_algebra.hpp>
|
||||
#include <migraphx/simplify_dyn_ops.hpp>
|
||||
#include <migraphx/simplify_qdq.hpp>
|
||||
#include <migraphx/simplify_reshapes.hpp>
|
||||
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::unordered_map<std::string, pass> create_passes_lookup()
|
||||
{
|
||||
std::unordered_map<std::string, pass> result;
|
||||
// clang-format off
|
||||
std::initializer_list<pass> passes = {
|
||||
auto_contiguous{},
|
||||
dead_code_elimination{},
|
||||
eliminate_allocation{},
|
||||
eliminate_common_subexpression{},
|
||||
eliminate_concat{},
|
||||
eliminate_contiguous{},
|
||||
eliminate_data_type{},
|
||||
eliminate_identity{},
|
||||
eliminate_pad{},
|
||||
fuse_pointwise{},
|
||||
fuse_reduce{},
|
||||
inline_module{},
|
||||
insert_pad{},
|
||||
normalize_ops{},
|
||||
optimize_module{},
|
||||
promote_literals{},
|
||||
propagate_constant{},
|
||||
rewrite_gelu{},
|
||||
rewrite_pooling{},
|
||||
rewrite_quantization{},
|
||||
rewrite_rnn{},
|
||||
simplify_algebra{},
|
||||
simplify_dyn_ops{},
|
||||
simplify_qdq{},
|
||||
simplify_reshapes{},
|
||||
};
|
||||
// clang-format on
|
||||
for(const auto& pass : passes)
|
||||
result[pass.name()] = pass;
|
||||
result["eliminate_dead_code"] = dead_code_elimination{};
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<pass> get_passes(const std::vector<std::string>& names)
|
||||
{
|
||||
std::vector<pass> result;
|
||||
static const std::unordered_map<std::string, pass> lookup = create_passes_lookup();
|
||||
std::transform(
|
||||
names.begin(), names.end(), std::back_inserter(result), [](const std::string& name) {
|
||||
if(not contains(lookup, name))
|
||||
MIGRAPHX_THROW("Unknown pass: " + name);
|
||||
return lookup.at(name);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
40
docker/rocm/migraphx/driver/passes.hpp
Normal file
40
docker/rocm/migraphx/driver/passes.hpp
Normal file
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
|
||||
#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
|
||||
|
||||
#include <migraphx/pass.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::vector<pass> get_passes(const std::vector<std::string>& names);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
157
docker/rocm/migraphx/driver/perf.cpp
Normal file
157
docker/rocm/migraphx/driver/perf.cpp
Normal file
@ -0,0 +1,157 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include "perf.hpp"
|
||||
|
||||
#include <migraphx/generate.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/register_target.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/time.hpp>
|
||||
#ifdef HAVE_GPU
|
||||
#include <migraphx/gpu/hip.hpp>
|
||||
#endif
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
using milliseconds = std::chrono::duration<double, std::milli>;
|
||||
|
||||
template <class T>
|
||||
auto get_hash(const T& x)
|
||||
{
|
||||
return std::hash<T>{}(x);
|
||||
}
|
||||
|
||||
parameter_map fill_param_map(parameter_map& m,
|
||||
const std::unordered_map<std::string, shape>& param_shapes,
|
||||
const target& t,
|
||||
bool offload)
|
||||
{
|
||||
for(auto&& x : param_shapes)
|
||||
{
|
||||
argument& arg = m[x.first];
|
||||
if(arg.empty())
|
||||
{
|
||||
assert(not x.second.dynamic());
|
||||
arg = generate_argument(x.second, get_hash(x.first), random_mode::random);
|
||||
}
|
||||
if(not offload)
|
||||
arg = t.copy_to(arg);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
parameter_map create_param_map(const program& p, const target& t, bool offload)
|
||||
{
|
||||
parameter_map m;
|
||||
for(auto&& x : p.get_parameter_shapes())
|
||||
{
|
||||
auto arg = generate_argument(x.second, get_hash(x.first), random_mode::random);
|
||||
if(offload)
|
||||
m[x.first] = arg;
|
||||
else
|
||||
m[x.first] = t.copy_to(arg);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
parameter_map create_param_map(const program& p, bool gpu)
|
||||
{
|
||||
parameter_map m;
|
||||
for(auto&& x : p.get_parameter_shapes())
|
||||
{
|
||||
#ifdef HAVE_GPU
|
||||
if(gpu)
|
||||
m[x.first] =
|
||||
gpu::to_gpu(generate_argument(x.second, get_hash(x.first), random_mode::random));
|
||||
else
|
||||
#else
|
||||
(void)gpu;
|
||||
#endif
|
||||
m[x.first] = generate_argument(x.second, get_hash(x.first), random_mode::random);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
target get_target(bool gpu)
|
||||
{
|
||||
if(gpu)
|
||||
return make_target("gpu");
|
||||
else
|
||||
return make_target("cpu");
|
||||
}
|
||||
|
||||
bool is_offload_copy_set(const program& p)
|
||||
{
|
||||
assert(p.is_compiled());
|
||||
const module* mm = p.get_main_module();
|
||||
std::vector<std::string> param_names = mm->get_parameter_names();
|
||||
std::unordered_set<instruction_ref> param_ins;
|
||||
std::transform(param_names.begin(),
|
||||
param_names.end(),
|
||||
std::inserter(param_ins, param_ins.begin()),
|
||||
[&](const auto& i) { return mm->get_parameter(i); });
|
||||
for(const auto& i : *mm)
|
||||
{
|
||||
if(i.name() == "hip::copy_to_gpu")
|
||||
{
|
||||
auto copy_arg = instruction::get_output_alias(i.inputs().front(), true);
|
||||
param_ins.erase(copy_arg);
|
||||
}
|
||||
else if(i.name() == "@return")
|
||||
{
|
||||
auto return_args = i.inputs();
|
||||
for(const auto& j : return_args)
|
||||
{
|
||||
auto alias_ins = instruction::get_output_alias(j, true);
|
||||
if((alias_ins->name() == "@param" and param_ins.erase(alias_ins) == 0) or
|
||||
(alias_ins->name() != "hip::copy_from_gpu"))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return param_ins.empty();
|
||||
}
|
||||
|
||||
double time_run(const program& p, const parameter_map& m, int n)
|
||||
{
|
||||
// Run once without timing
|
||||
p.eval(m);
|
||||
p.finish();
|
||||
double total = time<milliseconds>([&] {
|
||||
for(auto i : range(n))
|
||||
{
|
||||
(void)i;
|
||||
p.eval(m);
|
||||
}
|
||||
p.finish();
|
||||
});
|
||||
return total / n;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
58
docker/rocm/migraphx/driver/perf.hpp
Normal file
58
docker/rocm/migraphx/driver/perf.hpp
Normal file
@ -0,0 +1,58 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_PERF_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_PERF_HPP
|
||||
|
||||
#include <migraphx/program.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
parameter_map fill_param_map(parameter_map& m,
|
||||
const std::unordered_map<std::string, shape>& param_shapes,
|
||||
const target& t,
|
||||
bool offload = false);
|
||||
parameter_map create_param_map(const program& p, const target& t, bool offload = false);
|
||||
|
||||
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
|
||||
parameter_map create_param_map(const program& p, bool gpu = true);
|
||||
target get_target(bool gpu);
|
||||
/**
|
||||
* @brief Checks if MIGraphX program compiled for "GPU" has offload_copy set of not. This is
|
||||
intended to print a HINT for the users and would not always correctly classify compiled program as
|
||||
with or without offload_copy in all cases.
|
||||
|
||||
* @param p Compiled MIGraphX program for GPU backend
|
||||
* @return true if program is classified as compiled with "offload_copy" set
|
||||
*/
|
||||
bool is_offload_copy_set(const program& p);
|
||||
|
||||
double time_run(const program& p, const parameter_map& m, int n = 100);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
43
docker/rocm/migraphx/driver/precision.hpp
Normal file
43
docker/rocm/migraphx/driver/precision.hpp
Normal file
@ -0,0 +1,43 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
enum class precision
|
||||
{
|
||||
fp32,
|
||||
fp16,
|
||||
bf16,
|
||||
int8
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
340
docker/rocm/migraphx/driver/verify.cpp
Normal file
340
docker/rocm/migraphx/driver/verify.cpp
Normal file
@ -0,0 +1,340 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include "verify.hpp"
|
||||
#include "perf.hpp"
|
||||
|
||||
#include <migraphx/register_target.hpp>
|
||||
#include <migraphx/generate.hpp>
|
||||
#include <migraphx/verify_args.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/compile_options.hpp>
|
||||
#include <migraphx/quantization.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/fp_to_double.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
|
||||
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
|
||||
* model.
|
||||
*/
|
||||
verify::tolerance get_tolerances(const program& p,
|
||||
verify_options vo,
|
||||
std::optional<double> rms_tol,
|
||||
std::optional<double> atol,
|
||||
std::optional<double> rtol)
|
||||
{
|
||||
bool has_16bit = any_of(p.get_modules(), [](auto&& m) {
|
||||
return any_of(*m, [](auto&& ins) {
|
||||
return (ins.get_shape().type() == shape::half_type or
|
||||
ins.get_shape().type() == shape::bf16_type);
|
||||
});
|
||||
});
|
||||
migraphx::verify::tolerance result{};
|
||||
if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
|
||||
{
|
||||
result.rms_tol = 8e-2;
|
||||
result.atol = 4e-2;
|
||||
result.rtol = 4e-2;
|
||||
}
|
||||
if(rms_tol)
|
||||
{
|
||||
result.rms_tol = *rms_tol;
|
||||
}
|
||||
if(atol)
|
||||
{
|
||||
result.atol = *atol;
|
||||
}
|
||||
if(rtol)
|
||||
{
|
||||
result.rtol = *rtol;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<argument> run_ref(program p,
|
||||
const compile_options& options,
|
||||
const verify_options& vo,
|
||||
const parameter_map& inputs)
|
||||
{
|
||||
if(vo.ref_use_double)
|
||||
{
|
||||
run_passes(p, {fp_to_double{}});
|
||||
}
|
||||
p.compile(migraphx::make_target("ref"), options);
|
||||
auto out = p.eval(inputs);
|
||||
std::cout << p << std::endl;
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<argument> run_target(program p,
|
||||
const target& t,
|
||||
const compile_options& options,
|
||||
const verify_options& vo,
|
||||
const parameter_map& inputs)
|
||||
{
|
||||
if(vo.quantize == precision::fp16)
|
||||
{
|
||||
quantize_fp16(p);
|
||||
}
|
||||
if(vo.quantize == precision::bf16)
|
||||
{
|
||||
quantize_bf16(p);
|
||||
}
|
||||
p.compile(t, options);
|
||||
|
||||
parameter_map m;
|
||||
for(auto&& x : p.get_parameter_shapes())
|
||||
{
|
||||
auto arg = inputs.count(x.first) == 0 ? generate_argument(x.second) : inputs.at(x.first);
|
||||
m[x.first] = options.offload_copy ? arg : t.copy_to(arg);
|
||||
}
|
||||
auto gpu_out = p.eval(m);
|
||||
std::vector<argument> output(gpu_out.size());
|
||||
std::cout << p << std::endl;
|
||||
std::transform(gpu_out.begin(), gpu_out.end(), output.begin(), [&](auto& argu) {
|
||||
return options.offload_copy ? argu : t.copy_from(argu);
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
bool verify_program(const std::string& name,
|
||||
const program& p,
|
||||
const target& t,
|
||||
compile_options options,
|
||||
verify_options vo,
|
||||
const parameter_map& inputs,
|
||||
verify::tolerance tols)
|
||||
{
|
||||
auto ref_outs = run_ref(p, options, vo, inputs);
|
||||
auto target_outs = run_target(p, t, options, vo, inputs);
|
||||
|
||||
std::size_t output_num = ref_outs.size();
|
||||
bool passed = true;
|
||||
for(std::size_t i = 0; i < output_num; ++i)
|
||||
{
|
||||
if(ref_outs[i].get_shape().type() != target_outs[i].get_shape().type() or
|
||||
ref_outs[i].get_shape().lens() != target_outs[i].get_shape().lens())
|
||||
{
|
||||
std::cout << "FAILED: " << name << std::endl;
|
||||
std::cout << "Shape mismatch {" << ref_outs[i].get_shape() << "} != {"
|
||||
<< target_outs[i].get_shape() << "}" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
passed &= verify_args(name, target_outs[i], verify::expected{ref_outs[i]}, tols);
|
||||
}
|
||||
}
|
||||
if(passed)
|
||||
std::cout << "MIGraphX verification passed successfully." << std::endl;
|
||||
return passed;
|
||||
}
|
||||
|
||||
void verify_instructions(const program& prog,
|
||||
const target& t,
|
||||
compile_options options,
|
||||
verify_options vo,
|
||||
verify::tolerance tols)
|
||||
{
|
||||
const auto* mm_prog = prog.get_main_module();
|
||||
for(auto&& ins : (*mm_prog))
|
||||
{
|
||||
if(ins.name().front() == '@')
|
||||
continue;
|
||||
if(ins.name() == "broadcast")
|
||||
continue;
|
||||
if(ins.name() == "transpose")
|
||||
continue;
|
||||
if(ins.name() == "reshape")
|
||||
continue;
|
||||
if(ins.name() == "undefined")
|
||||
continue;
|
||||
program p;
|
||||
auto* mm_p = p.get_main_module();
|
||||
std::vector<instruction_ref> inputs;
|
||||
for(auto&& arg : ins.inputs())
|
||||
{
|
||||
if(arg->name() == "@literal")
|
||||
inputs.push_back(mm_p->add_literal(arg->get_literal()));
|
||||
else
|
||||
inputs.push_back(
|
||||
mm_p->add_parameter(std::to_string(inputs.size()), arg->get_shape()));
|
||||
}
|
||||
mm_p->add_instruction(ins.get_operator(), inputs);
|
||||
try
|
||||
{
|
||||
std::cout << "Verify: " << ins.name() << std::endl;
|
||||
std::cout << p << std::endl;
|
||||
verify_program(ins.name(), p, t, options, vo, create_param_map(p, false), tols);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool verify_reduced(program p,
|
||||
int n,
|
||||
const target& t,
|
||||
compile_options options,
|
||||
verify_options vo,
|
||||
const parameter_map& inputs,
|
||||
verify::tolerance tols)
|
||||
{
|
||||
auto* mm = p.get_main_module();
|
||||
auto last = std::prev(mm->end(), n);
|
||||
mm->remove_instructions(last, mm->end());
|
||||
std::cout << "Verify: " << n << std::endl;
|
||||
std::cout << p << std::endl;
|
||||
try
|
||||
{
|
||||
return verify_program(std::to_string(n), p, t, options, vo, inputs, tols);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cout << "FAILED: " << n << std::endl;
|
||||
std::cout << "Exception: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void verify_reduced_program(const program& p,
|
||||
const target& t,
|
||||
compile_options options,
|
||||
verify_options vo,
|
||||
const parameter_map& inputs,
|
||||
verify::tolerance tols)
|
||||
{
|
||||
const auto* mm = p.get_main_module();
|
||||
auto n = std::distance(mm->begin(), mm->end());
|
||||
std::cout << "Verify steps: " << n << std::endl;
|
||||
for(std::size_t i = 1; i < n; i++)
|
||||
{
|
||||
auto last = std::prev(mm->end(), i + 1);
|
||||
if(contains({"@literal", "@param"}, last->name()))
|
||||
{
|
||||
std::cout << "Skip: " << i << std::endl;
|
||||
continue;
|
||||
}
|
||||
verify_reduced(p, i, t, options, vo, inputs, tols);
|
||||
}
|
||||
}
|
||||
|
||||
static std::unordered_map<instruction_ref, std::size_t> accumulate_weights(instruction_ref last)
|
||||
{
|
||||
std::unordered_map<instruction_ref, std::size_t> weights;
|
||||
fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
|
||||
if(not contains(weights, ins))
|
||||
{
|
||||
if(ins->can_eval())
|
||||
return 0;
|
||||
std::size_t weight = 1;
|
||||
weights[ins] = std::accumulate(
|
||||
ins->inputs().begin(),
|
||||
ins->inputs().end(),
|
||||
weight,
|
||||
[&](std::size_t w, instruction_ref i) -> std::size_t { return w + self(i); });
|
||||
}
|
||||
return weights[ins];
|
||||
})(last);
|
||||
return weights;
|
||||
}
|
||||
|
||||
static optional<instruction_ref>
|
||||
get_parent(const std::unordered_map<instruction_ref, std::size_t>& weights, instruction_ref ins)
|
||||
{
|
||||
if(ins->inputs().empty())
|
||||
return nullopt;
|
||||
auto next = std::max_element(ins->inputs().begin(),
|
||||
ins->inputs().end(),
|
||||
by(std::less<>{}, [&](instruction_ref input) -> std::size_t {
|
||||
if(not contains(weights, input))
|
||||
return 0;
|
||||
return weights.at(input);
|
||||
}));
|
||||
return *next;
|
||||
}
|
||||
|
||||
static std::vector<std::size_t> find_trim_instructions(const module& m)
|
||||
{
|
||||
std::vector<std::size_t> result;
|
||||
auto last = std::prev(m.end());
|
||||
auto weights = accumulate_weights(last);
|
||||
auto next = get_parent(weights, last);
|
||||
std::size_t i = 0;
|
||||
while(auto parent = get_parent(weights, *next))
|
||||
{
|
||||
i += std::distance(*parent, *next);
|
||||
result.push_back(i + 1);
|
||||
next = parent;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void verify_bisected_program(const program& p,
|
||||
const target& t,
|
||||
compile_options options,
|
||||
verify_options vo,
|
||||
const parameter_map& inputs,
|
||||
verify::tolerance tols)
|
||||
{
|
||||
const auto* mm = p.get_main_module();
|
||||
|
||||
std::vector<std::size_t> trims = find_trim_instructions(*mm);
|
||||
std::int64_t right = trims.size();
|
||||
std::int64_t left = 0;
|
||||
std::int64_t failed = -1;
|
||||
|
||||
while(left <= right)
|
||||
{
|
||||
std::int64_t mid = left + (right - left) / 2;
|
||||
assert(mid < trims.size() and mid >= 0);
|
||||
std::int64_t trim = trims.rbegin()[mid];
|
||||
bool passed = verify_reduced(p, trim, t, options, vo, inputs, tols);
|
||||
if(passed)
|
||||
{
|
||||
left = mid + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
failed = trim;
|
||||
right = mid - 1;
|
||||
}
|
||||
}
|
||||
if(failed > 0)
|
||||
{
|
||||
std::cout << "Failure starts at: " << failed << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
70
docker/rocm/migraphx/driver/verify.hpp
Normal file
70
docker/rocm/migraphx/driver/verify.hpp
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
|
||||
|
||||
#include "verify_options.hpp"
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/verify.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
verify::tolerance get_tolerances(const program& p,
|
||||
verify_options vo,
|
||||
std::optional<double> rms_tol,
|
||||
std::optional<double> atol,
|
||||
std::optional<double> rtol);
|
||||
|
||||
bool verify_program(const std::string& name,
|
||||
const program& p,
|
||||
const target& t,
|
||||
compile_options options = compile_options{},
|
||||
verify_options vo = verify_options{},
|
||||
const parameter_map& inputs = {},
|
||||
verify::tolerance tols = verify::tolerance{});
|
||||
void verify_instructions(const program& prog,
|
||||
const target& t,
|
||||
compile_options options = compile_options{},
|
||||
verify_options vo = verify_options{},
|
||||
verify::tolerance tols = verify::tolerance{});
|
||||
void verify_reduced_program(const program& p,
|
||||
const target& t,
|
||||
compile_options options = compile_options{},
|
||||
verify_options vo = verify_options{},
|
||||
const parameter_map& inputs = {},
|
||||
verify::tolerance tols = verify::tolerance{});
|
||||
void verify_bisected_program(const program& p,
|
||||
const target& t,
|
||||
compile_options options = compile_options{},
|
||||
verify_options vo = verify_options{},
|
||||
const parameter_map& inputs = {},
|
||||
verify::tolerance tols = verify::tolerance{});
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
48
docker/rocm/migraphx/driver/verify_options.hpp
Normal file
48
docker/rocm/migraphx/driver/verify_options.hpp
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_OPTIONS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_OPTIONS_HPP
|
||||
|
||||
#include "precision.hpp"
|
||||
|
||||
namespace migraphx {
|
||||
namespace driver {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct verify_options
|
||||
{
|
||||
/// Quantization precision
|
||||
precision quantize = precision::fp32;
|
||||
|
||||
/**
|
||||
* Converts floating point values to double on the ref target.
|
||||
*/
|
||||
bool ref_use_double = false;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace driver
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
45
docker/rocm/migraphx/include/migraphx/adjust_allocation.hpp
Normal file
45
docker/rocm/migraphx/include/migraphx/adjust_allocation.hpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ADJUST_ALLOCATION_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ADJUST_ALLOCATION_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/allocation_model.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
struct MIGRAPHX_EXPORT adjust_allocation
|
||||
{
|
||||
allocation_model model;
|
||||
std::string name() const { return "adjust_allocation"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
205
docker/rocm/migraphx/include/migraphx/algorithm.hpp
Normal file
205
docker/rocm/migraphx/include/migraphx/algorithm.hpp
Normal file
@ -0,0 +1,205 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class Iterator, class Output, class Predicate, class F>
|
||||
void transform_if(Iterator start, Iterator last, Output out, Predicate pred, F f)
|
||||
{
|
||||
while(start != last)
|
||||
{
|
||||
if(pred(*start))
|
||||
{
|
||||
*out = f(*start);
|
||||
++out;
|
||||
}
|
||||
++start;
|
||||
}
|
||||
}
|
||||
|
||||
/// Similiar to std::accumulate but a projection can be applied to the elements first
|
||||
template <class Iterator, class T, class BinaryOp, class UnaryOp>
|
||||
T transform_accumulate(Iterator first, Iterator last, T init, BinaryOp binop, UnaryOp unaryop)
|
||||
{
|
||||
return std::inner_product(
|
||||
first, last, first, init, binop, [&](auto&& x, auto&&) { return unaryop(x); });
|
||||
}
|
||||
|
||||
/// Similiar to std::partial_sum but a projection can be applied to the elements first
|
||||
template <class Iterator, class OutputIterator, class BinaryOperation, class UnaryOp>
|
||||
OutputIterator transform_partial_sum(
|
||||
Iterator first, Iterator last, OutputIterator d_first, BinaryOperation binop, UnaryOp unaryop)
|
||||
{
|
||||
if(first == last)
|
||||
return d_first;
|
||||
|
||||
auto acc = unaryop(*first);
|
||||
*d_first = acc;
|
||||
|
||||
while(++first != last)
|
||||
{
|
||||
acc = binop(std::move(acc), unaryop(*first));
|
||||
*++d_first = acc;
|
||||
}
|
||||
|
||||
return ++d_first;
|
||||
}
|
||||
|
||||
template <class Iterator, class Output, class Predicate>
|
||||
void group_by(Iterator start, Iterator last, Output out, Predicate pred)
|
||||
{
|
||||
while(start != last)
|
||||
{
|
||||
auto it = std::partition(start, last, [&](auto&& x) { return pred(x, *start); });
|
||||
out(start, it);
|
||||
start = it;
|
||||
}
|
||||
}
|
||||
|
||||
template <class Iterator, class Output, class Predicate>
|
||||
void group_unique(Iterator start, Iterator last, Output out, Predicate pred)
|
||||
{
|
||||
while(start != last)
|
||||
{
|
||||
auto it = std::find_if(start, last, [&](auto&& x) { return not pred(*start, x); });
|
||||
out(start, it);
|
||||
start = it;
|
||||
}
|
||||
}
|
||||
|
||||
template <class Iterator, class Predicate, class Output>
|
||||
void group_find(Iterator start, Iterator last, Predicate pred, Output out)
|
||||
{
|
||||
start = std::find_if(start, last, pred);
|
||||
while(start != last)
|
||||
{
|
||||
auto it = std::find_if_not(start, last, pred);
|
||||
out(start, it);
|
||||
start = std::find_if(it, last, pred);
|
||||
}
|
||||
}
|
||||
|
||||
/// Similiar to std::remove_if but instead pass adjacent pairs to the predicate
|
||||
template <class Iterator, class Predicate>
|
||||
Iterator adjacent_remove_if(Iterator first, Iterator last, Predicate p)
|
||||
{
|
||||
first = std::adjacent_find(first, last, p);
|
||||
if(first == last)
|
||||
return first;
|
||||
auto i = first;
|
||||
while(std::next(++i) != last)
|
||||
{
|
||||
if(not p(*i, *std::next(i)))
|
||||
{
|
||||
*first = std::move(*i);
|
||||
++first;
|
||||
}
|
||||
}
|
||||
*first = std::move(*i);
|
||||
++first;
|
||||
return first;
|
||||
}
|
||||
|
||||
/// Similiar to std::for_each but instead pass adjacent pairs to the function
|
||||
template <class Iterator, class F>
|
||||
Iterator adjacent_for_each(Iterator first, Iterator last, F f)
|
||||
{
|
||||
if(first == last)
|
||||
return last;
|
||||
|
||||
Iterator next = first;
|
||||
++next;
|
||||
|
||||
for(; next != last; ++next, ++first)
|
||||
f(*first, *next);
|
||||
|
||||
return last;
|
||||
}
|
||||
|
||||
template <class Iterator1, class Iterator2>
|
||||
std::ptrdiff_t
|
||||
levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator2 last2)
|
||||
{
|
||||
if(first1 == last1)
|
||||
return std::distance(first2, last2);
|
||||
if(first2 == last2)
|
||||
return std::distance(first1, last1);
|
||||
if(*first1 == *first2)
|
||||
return levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
|
||||
auto x1 = levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
|
||||
auto x2 = levenshtein_distance(first1, last1, std::next(first2), last2);
|
||||
auto x3 = levenshtein_distance(std::next(first1), last1, first2, last2);
|
||||
return std::ptrdiff_t{1} + std::min({x1, x2, x3});
|
||||
}
|
||||
|
||||
inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
|
||||
{
|
||||
const size_t l1 = s1.length();
|
||||
const size_t l2 = s2.length();
|
||||
|
||||
if(l1 < l2)
|
||||
levenshtein_distance(s2, s1);
|
||||
|
||||
std::vector<size_t> d(l2 + 1);
|
||||
|
||||
std::iota(d.begin(), d.end(), 0);
|
||||
|
||||
for(size_t i = 1; i <= l1; i++)
|
||||
{
|
||||
size_t prev_cost = d[0];
|
||||
d[0] = i;
|
||||
|
||||
for(size_t j = 1; j <= l2; j++)
|
||||
{
|
||||
if(s1[i - 1] == s2[j - 1])
|
||||
{
|
||||
d[j] = prev_cost;
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t cost_insert_or_delete = std::min(d[j - 1], d[j]);
|
||||
size_t cost_substitute = prev_cost;
|
||||
prev_cost = d[j];
|
||||
d[j] = std::min(cost_substitute, cost_insert_or_delete) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return d[l2];
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
352
docker/rocm/migraphx/include/migraphx/allocation_model.hpp
Normal file
352
docker/rocm/migraphx/include/migraphx/allocation_model.hpp
Normal file
@ -0,0 +1,352 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP
|
||||
#define MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifdef DOXYGEN
|
||||
|
||||
/// An interface for target-dependent allocation
|
||||
struct allocation_model
|
||||
{
|
||||
/// A name of the target-dependent allocate operator
|
||||
std::string name() const;
|
||||
/// A name of the target-dependent copy operator
|
||||
std::string copy() const;
|
||||
/// Create an allocation operator for the given shape
|
||||
operation allocate(const shape& s) const;
|
||||
/// Create a preallocated operator for the given shape
|
||||
operation preallocate(const shape& s, const std::string& id) const;
|
||||
/// Check if outputs are to be inserted
|
||||
bool needs_out_params() const;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
#ifdef TYPE_ERASED_DECLARATION
|
||||
|
||||
// Type-erased interface for:
|
||||
struct MIGRAPHX_EXPORT allocation_model
|
||||
{
|
||||
//
|
||||
std::string name() const;
|
||||
//
|
||||
std::string copy() const;
|
||||
//
|
||||
operation allocate(const shape& s) const;
|
||||
//
|
||||
operation preallocate(const shape& s, std::string id) const;
|
||||
//
|
||||
bool needs_out_params() const;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
struct allocation_model
|
||||
{
|
||||
private:
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
struct private_te_unwrap_reference
|
||||
{
|
||||
using type = PrivateDetailTypeErasedT;
|
||||
};
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
struct private_te_unwrap_reference<std::reference_wrapper<PrivateDetailTypeErasedT>>
|
||||
{
|
||||
using type = PrivateDetailTypeErasedT;
|
||||
};
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_pure = typename std::remove_cv<
|
||||
typename std::remove_reference<PrivateDetailTypeErasedT>::type>::type;
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_constraints_impl =
|
||||
decltype(std::declval<PrivateDetailTypeErasedT>().name(),
|
||||
std::declval<PrivateDetailTypeErasedT>().copy(),
|
||||
std::declval<PrivateDetailTypeErasedT>().allocate(std::declval<const shape&>()),
|
||||
std::declval<PrivateDetailTypeErasedT>().preallocate(std::declval<const shape&>(),
|
||||
std::declval<std::string>()),
|
||||
std::declval<PrivateDetailTypeErasedT>().needs_out_params(),
|
||||
void());
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_constraints = private_te_constraints_impl<
|
||||
typename private_te_unwrap_reference<private_te_pure<PrivateDetailTypeErasedT>>::type>;
|
||||
|
||||
public:
|
||||
// Constructors
|
||||
allocation_model() = default;
|
||||
|
||||
template <
|
||||
typename PrivateDetailTypeErasedT,
|
||||
typename = private_te_constraints<PrivateDetailTypeErasedT>,
|
||||
typename = typename std::enable_if<
|
||||
not std::is_same<private_te_pure<PrivateDetailTypeErasedT>, allocation_model>{}>::type>
|
||||
allocation_model(PrivateDetailTypeErasedT&& value)
|
||||
: private_detail_te_handle_mem_var(
|
||||
std::make_shared<
|
||||
private_detail_te_handle_type<private_te_pure<PrivateDetailTypeErasedT>>>(
|
||||
std::forward<PrivateDetailTypeErasedT>(value)))
|
||||
{
|
||||
}
|
||||
|
||||
// Assignment
|
||||
template <
|
||||
typename PrivateDetailTypeErasedT,
|
||||
typename = private_te_constraints<PrivateDetailTypeErasedT>,
|
||||
typename = typename std::enable_if<
|
||||
not std::is_same<private_te_pure<PrivateDetailTypeErasedT>, allocation_model>{}>::type>
|
||||
allocation_model& operator=(PrivateDetailTypeErasedT&& value)
|
||||
{
|
||||
using std::swap;
|
||||
auto* derived = this->any_cast<private_te_pure<PrivateDetailTypeErasedT>>();
|
||||
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
|
||||
{
|
||||
*derived = std::forward<PrivateDetailTypeErasedT>(value);
|
||||
}
|
||||
else
|
||||
{
|
||||
allocation_model rhs(value);
|
||||
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Cast
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
PrivateDetailTypeErasedT* any_cast()
|
||||
{
|
||||
return this->type_id() == typeid(PrivateDetailTypeErasedT)
|
||||
? std::addressof(static_cast<private_detail_te_handle_type<
|
||||
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
|
||||
private_detail_te_get_handle())
|
||||
.private_detail_te_value)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
|
||||
{
|
||||
return this->type_id() == typeid(PrivateDetailTypeErasedT)
|
||||
? std::addressof(static_cast<const private_detail_te_handle_type<
|
||||
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
|
||||
private_detail_te_get_handle())
|
||||
.private_detail_te_value)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
const std::type_info& type_id() const
|
||||
{
|
||||
if(private_detail_te_handle_empty())
|
||||
return typeid(std::nullptr_t);
|
||||
else
|
||||
return private_detail_te_get_handle().type();
|
||||
}
|
||||
|
||||
std::string name() const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().name();
|
||||
}
|
||||
|
||||
std::string copy() const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().copy();
|
||||
}
|
||||
|
||||
operation allocate(const shape& s) const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().allocate(s);
|
||||
}
|
||||
|
||||
operation preallocate(const shape& s, std::string id) const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().preallocate(s, std::move(id));
|
||||
}
|
||||
|
||||
bool needs_out_params() const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().needs_out_params();
|
||||
}
|
||||
|
||||
friend bool is_shared(const allocation_model& private_detail_x,
|
||||
const allocation_model& private_detail_y)
|
||||
{
|
||||
return private_detail_x.private_detail_te_handle_mem_var ==
|
||||
private_detail_y.private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
private:
|
||||
struct private_detail_te_handle_base_type
|
||||
{
|
||||
virtual ~private_detail_te_handle_base_type() {}
|
||||
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
|
||||
virtual const std::type_info& type() const = 0;
|
||||
|
||||
virtual std::string name() const = 0;
|
||||
virtual std::string copy() const = 0;
|
||||
virtual operation allocate(const shape& s) const = 0;
|
||||
virtual operation preallocate(const shape& s, std::string id) const = 0;
|
||||
virtual bool needs_out_params() const = 0;
|
||||
};
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
struct private_detail_te_handle_type : private_detail_te_handle_base_type
|
||||
{
|
||||
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
|
||||
private_detail_te_handle_type(
|
||||
PrivateDetailTypeErasedT value,
|
||||
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
|
||||
nullptr)
|
||||
: private_detail_te_value(value)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
|
||||
private_detail_te_handle_type(
|
||||
PrivateDetailTypeErasedT value,
|
||||
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
|
||||
int>::type* = nullptr) noexcept
|
||||
: private_detail_te_value(std::move(value))
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
|
||||
{
|
||||
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
|
||||
}
|
||||
|
||||
const std::type_info& type() const override { return typeid(private_detail_te_value); }
|
||||
|
||||
std::string name() const override { return private_detail_te_value.name(); }
|
||||
|
||||
std::string copy() const override { return private_detail_te_value.copy(); }
|
||||
|
||||
operation allocate(const shape& s) const override
|
||||
{
|
||||
|
||||
return private_detail_te_value.allocate(s);
|
||||
}
|
||||
|
||||
operation preallocate(const shape& s, std::string id) const override
|
||||
{
|
||||
|
||||
return private_detail_te_value.preallocate(s, std::move(id));
|
||||
}
|
||||
|
||||
bool needs_out_params() const override
|
||||
{
|
||||
|
||||
return private_detail_te_value.needs_out_params();
|
||||
}
|
||||
|
||||
PrivateDetailTypeErasedT private_detail_te_value;
|
||||
};
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
|
||||
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
|
||||
{
|
||||
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
|
||||
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
bool private_detail_te_handle_empty() const
|
||||
{
|
||||
return private_detail_te_handle_mem_var == nullptr;
|
||||
}
|
||||
|
||||
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
|
||||
{
|
||||
assert(private_detail_te_handle_mem_var != nullptr);
|
||||
return *private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
private_detail_te_handle_base_type& private_detail_te_get_handle()
|
||||
{
|
||||
assert(private_detail_te_handle_mem_var != nullptr);
|
||||
if(private_detail_te_handle_mem_var.use_count() > 1)
|
||||
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
|
||||
return *private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
|
||||
};
|
||||
|
||||
template <typename ValueType>
|
||||
inline const ValueType* any_cast(const allocation_model* x)
|
||||
{
|
||||
return x->any_cast<ValueType>();
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline ValueType* any_cast(allocation_model* x)
|
||||
{
|
||||
return x->any_cast<ValueType>();
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline ValueType& any_cast(allocation_model& x)
|
||||
{
|
||||
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
|
||||
if(y == nullptr)
|
||||
throw std::bad_cast();
|
||||
return *y;
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline const ValueType& any_cast(const allocation_model& x)
|
||||
{
|
||||
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
|
||||
if(y == nullptr)
|
||||
throw std::bad_cast();
|
||||
return *y;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
48
docker/rocm/migraphx/include/migraphx/analyze_streams.hpp
Normal file
48
docker/rocm/migraphx/include/migraphx/analyze_streams.hpp
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ANALYZE_STREAMS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ANALYZE_STREAMS_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/stream_model.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
struct stream_race
|
||||
{
|
||||
instruction_ref ins;
|
||||
// The instruction that should before
|
||||
instruction_ref before;
|
||||
};
|
||||
|
||||
MIGRAPHX_EXPORT std::vector<stream_race> analyze_streams(const module& m,
|
||||
const stream_model& strmm);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
84
docker/rocm/migraphx/include/migraphx/any_ptr.hpp
Normal file
84
docker/rocm/migraphx/include/migraphx/any_ptr.hpp
Normal file
@ -0,0 +1,84 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/optional.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/type_name.hpp>
|
||||
#include <cassert>
|
||||
#include <string_view>
|
||||
#include <typeindex>
|
||||
#include <type_traits>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct any_ptr
|
||||
{
|
||||
any_ptr() = default;
|
||||
template <class T>
|
||||
any_ptr(T* p) : ptr(p), ti(typeid(T*)), name(get_name<T*>())
|
||||
{
|
||||
}
|
||||
|
||||
any_ptr(void* p, std::string_view pname) : ptr(p), name(pname) {}
|
||||
|
||||
void* get(std::string_view n) const
|
||||
{
|
||||
if(name != n)
|
||||
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} +
|
||||
" != " + std::string{n});
|
||||
return ptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T get() const
|
||||
{
|
||||
static_assert(std::is_pointer<T>{}, "Must be a pointer");
|
||||
assert(ptr != nullptr);
|
||||
if(ti and std::type_index{typeid(T)} != *ti)
|
||||
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
|
||||
else if(name != get_name<T>())
|
||||
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
|
||||
return reinterpret_cast<T>(ptr);
|
||||
}
|
||||
void* unsafe_get() const { return ptr; }
|
||||
|
||||
private:
|
||||
void* ptr = nullptr;
|
||||
optional<std::type_index> ti = nullopt;
|
||||
std::string_view name = "";
|
||||
|
||||
template <class T>
|
||||
static const std::string& get_name()
|
||||
{
|
||||
return get_type_name<std::remove_cv_t<std::remove_pointer_t<T>>>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
|
||||
67
docker/rocm/migraphx/include/migraphx/apply_alpha_beta.hpp
Normal file
67
docker/rocm/migraphx/include/migraphx/apply_alpha_beta.hpp
Normal file
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
|
||||
|
||||
#include "migraphx/make_op.hpp"
|
||||
#include "migraphx/normalize_attributes.hpp"
|
||||
#include "migraphx/operation.hpp"
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_EXPORT
|
||||
instruction_ref insert_apply_alpha_beta(module& m,
|
||||
instruction_ref pos,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const operation& op,
|
||||
const literal& alpha,
|
||||
const literal& beta);
|
||||
|
||||
template <typename T = float>
|
||||
instruction_ref insert_apply_alpha_beta(module& m,
|
||||
instruction_ref pos,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const operation& op,
|
||||
T alpha = 1,
|
||||
T beta = 0)
|
||||
{
|
||||
return insert_apply_alpha_beta(m, pos, args, op, literal{T{alpha}}, literal{T{beta}});
|
||||
}
|
||||
|
||||
template <typename T = float>
|
||||
instruction_ref add_apply_alpha_beta(module& m,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const operation& op,
|
||||
T alpha = 1,
|
||||
T beta = 0)
|
||||
{
|
||||
return insert_apply_alpha_beta(m, m.end(), args, op, alpha, beta);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_APPLY_ALPHA_BETA_HPP
|
||||
130
docker/rocm/migraphx/include/migraphx/argument.hpp
Normal file
130
docker/rocm/migraphx/include/migraphx/argument.hpp
Normal file
@ -0,0 +1,130 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_ARGUMENT_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_ARGUMENT_HPP
|
||||
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/raw_data.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/make_shared_array.hpp>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
// clang-format off
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* @brief Arguments passed to instructions
|
||||
*
|
||||
* An `argument` can represent a raw buffer of data that either be referenced from another element
|
||||
* or it can be owned by the argument.
|
||||
*
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT argument : raw_data<argument>
|
||||
{
|
||||
argument() = default;
|
||||
|
||||
explicit argument(const shape& s);
|
||||
|
||||
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
|
||||
argument(shape s, F d)
|
||||
: m_shape(std::move(s))
|
||||
|
||||
{
|
||||
assign_buffer([f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); });
|
||||
}
|
||||
template <class T>
|
||||
argument(shape s, T* d)
|
||||
: m_shape(std::move(s))
|
||||
{
|
||||
assign_buffer([d] { return reinterpret_cast<char*>(d); });
|
||||
}
|
||||
|
||||
template <class T>
|
||||
argument(shape s, std::shared_ptr<T> d)
|
||||
: m_shape(std::move(s))
|
||||
{
|
||||
assign_buffer([d] { return reinterpret_cast<char*>(d.get()); });
|
||||
}
|
||||
|
||||
argument(shape s, std::nullptr_t);
|
||||
|
||||
argument(const std::vector<argument>& args);
|
||||
|
||||
/// Provides a raw pointer to the data
|
||||
char* data() const;
|
||||
|
||||
/// Whether data is available
|
||||
bool empty() const;
|
||||
|
||||
const shape& get_shape() const;
|
||||
|
||||
argument reshape(const shape& s) const;
|
||||
|
||||
argument copy() const;
|
||||
|
||||
/// Make copy of the argument that is always sharing the data
|
||||
argument share() const;
|
||||
|
||||
std::vector<argument> get_sub_objects() const;
|
||||
|
||||
/// Return the ith element
|
||||
argument element(std::size_t i) const;
|
||||
|
||||
// Keeps the same data ordering as the given container
|
||||
template <class Iterator>
|
||||
void fill(Iterator start, Iterator end)
|
||||
{
|
||||
assert(std::distance(start, end) <= m_shape.elements());
|
||||
this->visit([&](auto output) {
|
||||
std::copy(start, end, output.begin());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
void assign_buffer(std::function<char*()> d);
|
||||
struct data_t
|
||||
{
|
||||
std::function<char*()> get = nullptr;
|
||||
std::vector<data_t> sub = {};
|
||||
data_t share() const;
|
||||
static data_t from_args(const std::vector<argument>& args);
|
||||
};
|
||||
argument(const shape& s, const data_t& d);
|
||||
shape m_shape;
|
||||
data_t m_data{};
|
||||
};
|
||||
|
||||
MIGRAPHX_EXPORT std::vector<argument> flatten(const std::vector<argument>& args);
|
||||
|
||||
MIGRAPHX_EXPORT std::vector<shape> to_shapes(const std::vector<argument>& args);
|
||||
MIGRAPHX_EXPORT void migraphx_to_value(value& v, const argument& a);
|
||||
MIGRAPHX_EXPORT void migraphx_from_value(const value& v, argument& a);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
// clang-format on
|
||||
|
||||
#endif
|
||||
98
docker/rocm/migraphx/include/migraphx/array.hpp
Normal file
98
docker/rocm/migraphx/include/migraphx/array.hpp
Normal file
@ -0,0 +1,98 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/requires.hpp>
|
||||
#include <type_traits>
|
||||
#include <array>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class R, class...>
|
||||
struct array_type
|
||||
{
|
||||
using type = R;
|
||||
};
|
||||
template <class... Ts>
|
||||
struct array_type<void, Ts...> : std::common_type<Ts...>
|
||||
{
|
||||
};
|
||||
|
||||
template <class R, class... Ts>
|
||||
using array_type_t = typename array_type<R, Ts...>::type;
|
||||
|
||||
template <class T, std::size_t N, std::size_t... I>
|
||||
constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], seq<I...>)
|
||||
{
|
||||
return {{a[I]...}};
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
|
||||
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> make_array(Ts&&... xs)
|
||||
{
|
||||
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))...};
|
||||
}
|
||||
|
||||
constexpr std::array<int, 0> make_array() { return {}; }
|
||||
|
||||
template <class T, std::size_t N>
|
||||
constexpr auto to_array(T (&a)[N])
|
||||
{
|
||||
return detail::to_array_impl(a, detail::gens<N>{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <std::size_t Offset = 0, class Array, std::size_t... I>
|
||||
constexpr auto rearray_impl(Array a, seq<I...>)
|
||||
{
|
||||
return make_array(a[I + Offset]...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class T, std::size_t N>
|
||||
constexpr auto pop_front(std::array<T, N> a)
|
||||
{
|
||||
return detail::rearray_impl(a, detail::gens<N - 1>{});
|
||||
}
|
||||
|
||||
template <class T, std::size_t N>
|
||||
constexpr auto pop_back(std::array<T, N> a)
|
||||
{
|
||||
return detail::rearray_impl<1>(a, detail::gens<N - 1>{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
43
docker/rocm/migraphx/include/migraphx/as_number.hpp
Normal file
43
docker/rocm/migraphx/include/migraphx/as_number.hpp
Normal file
@ -0,0 +1,43 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_AS_NUMBER_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_AS_NUMBER_HPP
|
||||
|
||||
#include <cstdint>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class T>
|
||||
T as_number(T x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); }
|
||||
inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); }
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_RTGLIB_AS_NUMBER_HPP
|
||||
61
docker/rocm/migraphx/include/migraphx/assert.hpp
Normal file
61
docker/rocm/migraphx/include/migraphx/assert.hpp
Normal file
@ -0,0 +1,61 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class F>
|
||||
auto abort_on_throw(F f) -> decltype(f())
|
||||
{
|
||||
try
|
||||
{
|
||||
return f();
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << e.what() << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
std::cerr << "Unknown exception" << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
#ifdef NDEBUG
|
||||
#define MIGRAPHX_ASSERT_NO_THROW(...) __VA_ARGS__
|
||||
#else
|
||||
#define MIGRAPHX_ASSERT_NO_THROW(...) \
|
||||
migraphx::abort_on_throw([&]() -> decltype(__VA_ARGS__) { return __VA_ARGS__; })
|
||||
#endif
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
|
||||
40
docker/rocm/migraphx/include/migraphx/assignment_options.hpp
Normal file
40
docker/rocm/migraphx/include/migraphx/assignment_options.hpp
Normal file
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
|
||||
|
||||
#include <migraphx/support_metric.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct assignment_options
|
||||
{
|
||||
support_metric metric = support_metric::latency;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif // MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
|
||||
67
docker/rocm/migraphx/include/migraphx/auto_any_cast.hpp
Normal file
67
docker/rocm/migraphx/include/migraphx/auto_any_cast.hpp
Normal file
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
// Forward declare any_cast
|
||||
template <class T>
|
||||
const T& any_cast(const T&);
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class U>
|
||||
void any_cast()
|
||||
{
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct auto_any_caster
|
||||
{
|
||||
T& x; // NOLINT
|
||||
|
||||
template <class U>
|
||||
operator U&()
|
||||
{
|
||||
return any_cast<U>(x);
|
||||
}
|
||||
|
||||
operator T&() { return x; }
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class T>
|
||||
detail::auto_any_caster<T> auto_any_cast(T& x)
|
||||
{
|
||||
return {x};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
45
docker/rocm/migraphx/include/migraphx/auto_contiguous.hpp
Normal file
45
docker/rocm/migraphx/include/migraphx/auto_contiguous.hpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
struct MIGRAPHX_EXPORT auto_contiguous
|
||||
{
|
||||
std::string name() const { return "auto_contiguous"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
72
docker/rocm/migraphx/include/migraphx/auto_register.hpp
Normal file
72
docker/rocm/migraphx/include/migraphx/auto_register.hpp
Normal file
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_REGISTER_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_AUTO_REGISTER_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <type_traits>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class Action, class T>
|
||||
int auto_register_action()
|
||||
{
|
||||
Action::template apply<T>();
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class Action, class T>
|
||||
struct auto_register
|
||||
{
|
||||
const static int static_register;
|
||||
// This typedef ensures that the static member will be instantiated if
|
||||
// the class itself is instantiated
|
||||
using static_register_type =
|
||||
std::integral_constant<decltype(&static_register), &static_register>;
|
||||
};
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wglobal-constructors"
|
||||
#endif
|
||||
|
||||
template <class Action, class T>
|
||||
const int auto_register<Action, T>::static_register = auto_register_action<Action, T>(); // NOLINT
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x
|
||||
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_AUTO_REGISTER(...) \
|
||||
[[maybe_unused]] void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)( \
|
||||
migraphx::auto_register<__VA_ARGS__> x = migraphx::auto_register<__VA_ARGS__>{});
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
48
docker/rocm/migraphx/include/migraphx/autocast_fp8.hpp
Normal file
48
docker/rocm/migraphx/include/migraphx/autocast_fp8.hpp
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_AUTOCAST_FP8_HPP
|
||||
#define MIGRAPHX_GUARD_AMDMIGRAPHX_AUTOCAST_FP8_HPP
|
||||
|
||||
#include <migraphx/shape.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct program;
|
||||
struct module;
|
||||
|
||||
/**
|
||||
This pass will convert model with fp8 input parameter to model with fp32
|
||||
input parameter and internally add casts to fp8 for those converted params.*/
|
||||
struct MIGRAPHX_EXPORT autocast_fp8_pass
|
||||
{
|
||||
shape::type_t target_type = migraphx::shape::float_type;
|
||||
std::string name() const { return "autocast_fp8_pass"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
39
docker/rocm/migraphx/include/migraphx/base64.hpp
Normal file
39
docker/rocm/migraphx/include/migraphx/base64.hpp
Normal file
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_BASE64_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_BASE64_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/// encode string to base64
|
||||
std::string MIGRAPHX_EXPORT base64_encode(const std::string& str);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
39
docker/rocm/migraphx/include/migraphx/bf16.hpp
Normal file
39
docker/rocm/migraphx/include/migraphx/bf16.hpp
Normal file
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_BF16_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_BF16_HPP
|
||||
|
||||
#include <migraphx/generic_float.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
using bf16 = migraphx::generic_float<7, 8>;
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
57
docker/rocm/migraphx/include/migraphx/bit_cast.hpp
Normal file
57
docker/rocm/migraphx/include/migraphx/bit_cast.hpp
Normal file
@ -0,0 +1,57 @@
|
||||
/* ************************************************************************
|
||||
* Copyright (C) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
|
||||
* ies of the Software, and to permit persons to whom the Software is furnished
|
||||
* to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
|
||||
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
|
||||
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
* ************************************************************************ */
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
|
||||
#include <type_traits>
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#pragma GCC diagnostic ignored "-Wduplicated-branches"
|
||||
#endif
|
||||
|
||||
#include <migraphx/requires.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||||
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
template <typename To,
|
||||
typename From,
|
||||
MIGRAPHX_REQUIRES(std::is_trivially_copyable<To>{} and
|
||||
std::is_trivially_copyable<From>{})>
|
||||
inline constexpr To bit_cast(From fr) noexcept
|
||||
{
|
||||
static_assert(sizeof(To) == sizeof(From));
|
||||
#if defined(__GNUC__) and !defined(__clang__)
|
||||
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
|
||||
#else
|
||||
return __builtin_bit_cast(To, fr);
|
||||
#endif
|
||||
}
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
|
||||
95
docker/rocm/migraphx/include/migraphx/bit_signal.hpp
Normal file
95
docker/rocm/migraphx/include/migraphx/bit_signal.hpp
Normal file
@ -0,0 +1,95 @@
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_BIT_SIGNAL_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_BIT_SIGNAL_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <bitset>
|
||||
#include <cassert>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/// Observer pattern for keeping track of if something has changed or been
|
||||
/// updated. Can have up to `N` different subscribers. Use by creating a
|
||||
/// `bit_signal` and adding subscribers with `bit_signal.subscribe()`. Use
|
||||
/// `bit_signal.notify()` to set that subscribers should be notified. Get the
|
||||
/// status of the subscription by checking the `slot` returned by
|
||||
/// `bit_signal.subscribe()`.
|
||||
template <std::size_t N>
|
||||
struct bit_signal
|
||||
{
|
||||
std::bitset<N> slots;
|
||||
std::bitset<N> allocated;
|
||||
|
||||
struct slot
|
||||
{
|
||||
bit_signal* handler = nullptr;
|
||||
std::size_t i = N;
|
||||
|
||||
slot() = default;
|
||||
|
||||
slot(bit_signal* h, std::size_t x) : handler(h), i(x) {}
|
||||
|
||||
slot(slot&& rhs) noexcept : handler(rhs.handler), i(rhs.i)
|
||||
{
|
||||
rhs.handler = nullptr;
|
||||
rhs.i = N;
|
||||
}
|
||||
|
||||
slot(const slot& rhs) : handler(rhs.handler), i(rhs.handler->allocate()) {}
|
||||
|
||||
slot& operator=(slot rhs)
|
||||
{
|
||||
std::swap(handler, rhs.handler);
|
||||
std::swap(i, rhs.i);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~slot() noexcept
|
||||
{
|
||||
if(valid())
|
||||
handler->deallocate(i);
|
||||
}
|
||||
|
||||
bool valid() const { return i < N and handler != nullptr; }
|
||||
|
||||
bool triggered() const
|
||||
{
|
||||
assert(valid());
|
||||
return handler->triggered(i);
|
||||
}
|
||||
|
||||
operator bool() const { return triggered(); }
|
||||
};
|
||||
|
||||
slot subscribe() { return {this, allocate()}; }
|
||||
|
||||
std::size_t allocate()
|
||||
{
|
||||
auto i = *find_if(range(N), [&](auto x) { return not allocated[x]; });
|
||||
if(i == N)
|
||||
MIGRAPHX_THROW("Too many signals allocated");
|
||||
slots[i] = false;
|
||||
allocated[i] = true;
|
||||
return i;
|
||||
}
|
||||
|
||||
void deallocate(std::size_t i) { allocated[i] = false; }
|
||||
|
||||
void notify() { slots.set(); }
|
||||
|
||||
bool triggered(std::size_t i) const { return slots[i]; }
|
||||
|
||||
void clear()
|
||||
{
|
||||
slots.reset();
|
||||
allocated.reset();
|
||||
}
|
||||
|
||||
std::size_t nslots() const { return allocated.count(); }
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_BIT_SIGNAL_HPP
|
||||
114
docker/rocm/migraphx/include/migraphx/builtin.hpp
Normal file
114
docker/rocm/migraphx/include/migraphx/builtin.hpp
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_BUILTIN_HPP
|
||||
#define MIGRAPHX_GUARD_BUILTIN_HPP
|
||||
|
||||
#include <migraphx/context.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/argument.hpp>
|
||||
#include <migraphx/reflect.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
namespace builtin {
|
||||
|
||||
struct literal
|
||||
{
|
||||
std::string name() const { return "@literal"; }
|
||||
shape compute_shape(const std::vector<shape>&) const { MIGRAPHX_THROW("builtin"); }
|
||||
argument compute(context&, const shape&, const std::vector<argument>&) const
|
||||
{
|
||||
MIGRAPHX_THROW("builtin");
|
||||
}
|
||||
};
|
||||
|
||||
struct outline
|
||||
{
|
||||
shape s;
|
||||
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.s, "shape"));
|
||||
}
|
||||
|
||||
std::string name() const { return "@outline"; }
|
||||
shape compute_shape(const std::vector<shape>&) const { return s; }
|
||||
argument compute(context&, const shape&, const std::vector<argument>&) const
|
||||
{
|
||||
MIGRAPHX_THROW("builtin");
|
||||
}
|
||||
};
|
||||
|
||||
struct param
|
||||
{
|
||||
std::string parameter;
|
||||
uint32_t order = 0;
|
||||
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.parameter, "parameter"));
|
||||
}
|
||||
|
||||
std::string name() const { return "@param"; }
|
||||
shape compute_shape(const std::vector<shape>&) const { MIGRAPHX_THROW("builtin"); }
|
||||
argument compute(context&, const shape&, const std::vector<argument>&) const
|
||||
{
|
||||
MIGRAPHX_THROW("builtin");
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& os, const param& op)
|
||||
{
|
||||
os << op.name() << ":" << op.parameter;
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct returns
|
||||
{
|
||||
std::string name() const { return "@return"; }
|
||||
|
||||
shape compute_shape(const std::vector<shape>& arg) const
|
||||
{
|
||||
if(arg.empty())
|
||||
return {};
|
||||
else if(arg.size() == 1)
|
||||
return arg[0];
|
||||
else
|
||||
return shape(arg);
|
||||
}
|
||||
|
||||
argument compute(context&, const shape&, const std::vector<argument>&) const
|
||||
{
|
||||
MIGRAPHX_THROW("builtin");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
87
docker/rocm/migraphx/include/migraphx/check_context.hpp
Normal file
87
docker/rocm/migraphx/include/migraphx/check_context.hpp
Normal file
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_CONTEXT_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_CHECK_CONTEXT_HPP
|
||||
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class T>
|
||||
struct check_context
|
||||
{
|
||||
struct op : auto_register_op<op>
|
||||
{
|
||||
static std::string compute_op_name()
|
||||
{
|
||||
const auto& op_type_name = get_type_name<T>();
|
||||
const auto& split_name = split_string(op_type_name, ':');
|
||||
std::vector<std::string> name_without_version = {"check_context"};
|
||||
// op_type_name would contain internal namespace name with version_x_y_z
|
||||
// remove version and construct op_name such as check_context::migraphx::gpu::context
|
||||
std::copy_if(
|
||||
split_name.begin(),
|
||||
split_name.end(),
|
||||
std::back_inserter(name_without_version),
|
||||
[&](const auto& i) { return not i.empty() and not contains(i, "version"); });
|
||||
return join_strings(name_without_version, "::");
|
||||
}
|
||||
|
||||
std::string name() const
|
||||
{
|
||||
static auto op_name = compute_op_name();
|
||||
return op_name;
|
||||
}
|
||||
|
||||
shape compute_shape(const std::vector<shape>&) const { return {}; }
|
||||
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
|
||||
{
|
||||
this->check(ctx);
|
||||
return {};
|
||||
}
|
||||
void finalize(context& ctx, const shape&, const std::vector<shape>&) const
|
||||
{
|
||||
this->check(ctx);
|
||||
}
|
||||
void check(context& ctx) const
|
||||
{
|
||||
T* x = any_cast<T>(&ctx);
|
||||
if(x == nullptr)
|
||||
MIGRAPHX_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
|
||||
}
|
||||
};
|
||||
|
||||
std::string name() const { return "check_context"; }
|
||||
void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
431
docker/rocm/migraphx/include/migraphx/check_shapes.hpp
Normal file
431
docker/rocm/migraphx/include/migraphx/check_shapes.hpp
Normal file
@ -0,0 +1,431 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
|
||||
|
||||
#include <migraphx/permutation.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <algorithm>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
// Check that deduced type is incrementable, dereferencable, and comparable
|
||||
template <class, class = void>
|
||||
struct is_iterator
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_iterator<T,
|
||||
std::void_t<decltype(++std::declval<T&>()),
|
||||
decltype(*std::declval<T&>()),
|
||||
decltype(std::declval<T&>() == std::declval<T&>())>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class Iterator>
|
||||
struct check_shapes
|
||||
{
|
||||
static_assert(is_iterator<Iterator>{}, "CHECK_SHAPES: Deduced type must be an iterator");
|
||||
Iterator begin;
|
||||
Iterator end;
|
||||
std::string name;
|
||||
bool dynamic_allowed;
|
||||
|
||||
check_shapes(Iterator b, Iterator e, const std::string& n, const bool d = false)
|
||||
: begin(b), end(e), name(n), dynamic_allowed(d)
|
||||
{
|
||||
check_dynamic();
|
||||
}
|
||||
|
||||
template <class Op>
|
||||
check_shapes(Iterator b, Iterator e, const Op& op, const bool d = false)
|
||||
: begin(b), end(e), name(op.name()), dynamic_allowed(d)
|
||||
{
|
||||
check_dynamic();
|
||||
}
|
||||
|
||||
template <class Op, MIGRAPHX_REQUIRES(not std::is_convertible<Op, std::string>{})>
|
||||
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
|
||||
: begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
|
||||
{
|
||||
check_dynamic();
|
||||
}
|
||||
|
||||
check_shapes(const std::vector<shape>& s, const std::string& n, const bool d = false)
|
||||
: begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d)
|
||||
{
|
||||
check_dynamic();
|
||||
}
|
||||
|
||||
void check_dynamic() const
|
||||
{
|
||||
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
|
||||
{
|
||||
MIGRAPHX_THROW(prefix() + "Dynamic shapes not supported");
|
||||
}
|
||||
}
|
||||
|
||||
std::string prefix() const
|
||||
{
|
||||
if(name.empty())
|
||||
return "";
|
||||
else
|
||||
return name + ": ";
|
||||
}
|
||||
|
||||
std::size_t size() const
|
||||
{
|
||||
if(begin == end)
|
||||
return 0;
|
||||
return end - begin;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Require the number of shape objects to equal to one of the
|
||||
* given sizes.
|
||||
* \param ns template parameter pack of sizes to check against
|
||||
*/
|
||||
template <class... Ts>
|
||||
const check_shapes& has(Ts... ns) const
|
||||
{
|
||||
if(migraphx::none_of({ns...}, [&](auto i) { return this->size() == i; }))
|
||||
MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected " +
|
||||
to_string_range({ns...}) + " but given " + std::to_string(size()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Require the number of shape objects to equal at least a given amount. Use this
|
||||
* method for ops that can take any number (variadic) of inputs.
|
||||
* \param n min. number of shapes
|
||||
*/
|
||||
const check_shapes& has_at_least(std::size_t n) const
|
||||
{
|
||||
if(this->size() < n)
|
||||
MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected at least " +
|
||||
to_string(n) + " but given " + std::to_string(size()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Require all shapes to have the same number of elements.
|
||||
* \param n number of
|
||||
*/
|
||||
const check_shapes& nelements(std::size_t n) const
|
||||
{
|
||||
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check that the first shape has exactly n dimensions.
|
||||
* Do nothing if the container is empty.
|
||||
* \param n number of dimensions
|
||||
*/
|
||||
const check_shapes& only_dims(std::size_t n) const
|
||||
{
|
||||
if(begin != end)
|
||||
{
|
||||
if(begin->ndim() != n)
|
||||
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check that the first shape has a maximum of n dimensions.
|
||||
* Do nothing if the container is empty.
|
||||
* \param n number of dimensions
|
||||
*/
|
||||
const check_shapes& max_ndims(std::size_t n) const
|
||||
{
|
||||
if(begin != end)
|
||||
{
|
||||
if(begin->ndim() > n)
|
||||
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
|
||||
" dimensions");
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check that the first shape has a minimum of n dimensions.
|
||||
* Do nothing if the container is empty.
|
||||
* \param n number of dimensions
|
||||
*/
|
||||
const check_shapes& min_ndims(std::size_t n) const
|
||||
{
|
||||
if(begin != end)
|
||||
{
|
||||
if(begin->ndim() < n)
|
||||
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
|
||||
" dimensions");
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes have the same shape.
|
||||
*/
|
||||
const check_shapes& same_shape() const
|
||||
{
|
||||
if(not this->same([](const shape& s) { return s; }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes do not match");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes have the same type.
|
||||
*/
|
||||
const check_shapes& same_type() const
|
||||
{
|
||||
if(not this->same([](const shape& s) { return s.type(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Types do not match");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes have the same lens.
|
||||
*/
|
||||
const check_shapes& same_dims() const
|
||||
{
|
||||
if(not this->same([](const shape& s) { return s.max_lens(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Dimensions do not match");
|
||||
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
|
||||
if(not this->same([](const shape& s) { return s.min_lens(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes have the same number of dimensions.
|
||||
*/
|
||||
const check_shapes& same_ndims() const
|
||||
{
|
||||
if(not this->same([](const shape& s) { return s.ndim(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes have the same layout.
|
||||
*/
|
||||
const check_shapes& same_layout() const
|
||||
{
|
||||
if(not this->same([](const shape& s) { return find_permutation(s); }))
|
||||
MIGRAPHX_THROW(prefix() + "Layouts do not match");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are standard.
|
||||
*/
|
||||
const check_shapes& standard() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.standard(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are scalar.
|
||||
*/
|
||||
const check_shapes& scalar() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.scalar(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are standard or scalar.
|
||||
*/
|
||||
const check_shapes& standard_or_scalar() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are packed.
|
||||
*/
|
||||
const check_shapes& packed() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.packed(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not packed");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are packed with certain layouts
|
||||
*/
|
||||
const check_shapes&
|
||||
packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
|
||||
{
|
||||
if(not this->all_of([&](const shape& s) {
|
||||
return s.packed() and contains(layouts, find_permutation(s));
|
||||
}))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are packed or broadcasted.
|
||||
*/
|
||||
const check_shapes& packed_or_broadcasted() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are tuples.
|
||||
*/
|
||||
const check_shapes& tuple_type() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are not transposed.
|
||||
*/
|
||||
const check_shapes& not_transposed() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return not s.transposed(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are transposed");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes are not broadcasted.
|
||||
*/
|
||||
const check_shapes& not_broadcasted() const
|
||||
{
|
||||
if(not this->all_of([](const shape& s) { return s.standard() or not s.broadcasted(); }))
|
||||
MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check all shapes have the same n elements.
|
||||
* \param n number of elements
|
||||
*/
|
||||
const check_shapes& elements(std::size_t n) const
|
||||
{
|
||||
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
|
||||
MIGRAPHX_THROW(prefix() + "Wrong number of elements");
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check the batches of all the shapes do not have transposed strides.
|
||||
*/
|
||||
const check_shapes& batch_not_transposed() const
|
||||
{
|
||||
if(not this->all_of(
|
||||
[&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
|
||||
MIGRAPHX_THROW(prefix() + "Batch size is transposed");
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class F>
|
||||
bool same(F f) const
|
||||
{
|
||||
if(begin == end)
|
||||
return true;
|
||||
auto&& key = f(*begin);
|
||||
return this->all_of([&](const shape& s) { return f(s) == key; });
|
||||
}
|
||||
|
||||
template <class Predicate>
|
||||
bool all_of(Predicate p) const
|
||||
{
|
||||
if(begin == end)
|
||||
return true;
|
||||
return std::all_of(begin, end, p);
|
||||
}
|
||||
|
||||
template <class Predicate>
|
||||
bool any_of(Predicate p) const
|
||||
{
|
||||
if(begin == end)
|
||||
return false;
|
||||
return std::any_of(begin, end, p);
|
||||
}
|
||||
|
||||
Iterator get(long i) const
|
||||
{
|
||||
if(i >= size())
|
||||
MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
|
||||
if(i < 0)
|
||||
return end - i;
|
||||
return begin + i;
|
||||
}
|
||||
|
||||
check_shapes slice(long start) const { return {get(start), end, name}; }
|
||||
|
||||
check_shapes slice(long start, long last) const { return {get(start), get(last), name}; }
|
||||
|
||||
private:
|
||||
static bool batch_not_transposed_strides(const std::vector<std::size_t>& strides)
|
||||
{
|
||||
if(strides.size() <= 2)
|
||||
return true;
|
||||
auto dim_0 = strides.size() - 2;
|
||||
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
|
||||
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
|
||||
if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
|
||||
return (i < j or i < matrix_size or j < matrix_size);
|
||||
}) != batch.end())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Deduction guide for std::vector constructor
|
||||
template <class Op>
|
||||
check_shapes(const std::vector<shape>&, const Op&, bool d = false)
|
||||
-> check_shapes<std::vector<shape>::const_iterator>;
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
48
docker/rocm/migraphx/include/migraphx/clamp.hpp
Normal file
48
docker/rocm/migraphx/include/migraphx/clamp.hpp
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/float_equal.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class U, class T>
|
||||
U pad_clamp(T x)
|
||||
{
|
||||
if(float_equal(x, std::numeric_limits<T>::lowest()))
|
||||
return std::numeric_limits<U>::lowest();
|
||||
if(float_equal(x, std::numeric_limits<T>::max()))
|
||||
return std::numeric_limits<U>::max();
|
||||
return (x < std::numeric_limits<U>::lowest())
|
||||
? std::numeric_limits<U>::lowest()
|
||||
: (std::numeric_limits<U>::max() < x) ? std::numeric_limits<U>::max() : U(x);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
72
docker/rocm/migraphx/include/migraphx/cloneable.hpp
Normal file
72
docker/rocm/migraphx/include/migraphx/cloneable.hpp
Normal file
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_CLONEABLE_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_CLONEABLE_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <memory>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <typename Base>
|
||||
struct cloneable
|
||||
{
|
||||
friend Base;
|
||||
|
||||
virtual std::shared_ptr<Base> clone() = 0;
|
||||
|
||||
template <typename Derived>
|
||||
struct derive : Base
|
||||
{
|
||||
friend Derived;
|
||||
|
||||
std::shared_ptr<Base> clone() override
|
||||
{
|
||||
return std::make_shared<Derived>(static_cast<const Derived&>(*this));
|
||||
}
|
||||
template <typename... Args>
|
||||
derive(Args&&... args) : Base(std::forward<Args>(args)...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct share : Base, std::enable_shared_from_this<Base>
|
||||
{
|
||||
std::shared_ptr<Base> clone() override { return this->shared_from_this(); }
|
||||
template <typename... Args>
|
||||
share(Args&&... args) : Base(std::forward<Args>(args)...)
|
||||
{
|
||||
}
|
||||
};
|
||||
cloneable() = default;
|
||||
cloneable(const cloneable&) = default;
|
||||
cloneable& operator=(const cloneable&) = default;
|
||||
virtual ~cloneable() {}
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
158
docker/rocm/migraphx/include/migraphx/common.hpp
Normal file
158
docker/rocm/migraphx/include/migraphx/common.hpp
Normal file
@ -0,0 +1,158 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
struct operation;
|
||||
|
||||
struct common_options
|
||||
{
|
||||
bool common_type = true;
|
||||
bool common_lens = true;
|
||||
};
|
||||
|
||||
/**
|
||||
* Broadcasting works by comparing the shapes element-wise starting with
|
||||
* the trailing (right-most) dimensions and working leftwards. This is equivalent
|
||||
* to what is done in NumPy.
|
||||
* example 1:
|
||||
* s0 = (3,2,4,5) and s1 = (2,1,1)
|
||||
* In this case we need to broadcast (:,1,1) portion of
|
||||
* s1 plus broadcast the 1st dimension of s0
|
||||
* giving output_lens = (3,2,4,5)
|
||||
*
|
||||
* example 2:
|
||||
* s0 = (3,2,1,5) and s1 = (2,7,5)
|
||||
* In this case we need to broadcast the (:,:,1:,:) axis
|
||||
* of s0 plus the 1st dimension of s1 giving
|
||||
* output_lens = (3,2,7,5)
|
||||
*
|
||||
* example 3:
|
||||
* s0 = (4, 1, 1) and s1 = (3, 4)
|
||||
* output_lens = (4, 3, 4)
|
||||
*/
|
||||
MIGRAPHX_EXPORT
|
||||
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
|
||||
std::vector<std::size_t> s1);
|
||||
|
||||
/**
|
||||
* Broadcasting for two vectors of dynamic_dimensions.
|
||||
* Compares `dynamic_dimension` objects from the trailing (right-most) dimension and working
|
||||
* leftwards.
|
||||
*
|
||||
* Rules for broadcasting dynamic_dimension:
|
||||
* If the same `dynamic_dimension`, return either.
|
||||
* If one of the `dynamic_dimension`s is 1, return the other one.
|
||||
* If the `dynamic_dimension`s have an intersection, return the intersection.
|
||||
* Explanation:
|
||||
* For the shape to be broadcastable at runtime (when the dimensions are constant) the dimensions
|
||||
* must be the same. The only way for the dimensions to be the same is if the output dimension is
|
||||
* the intersection of the ranges.
|
||||
* In practice, we will mostly see this case for handling unknown dynamic_dimensions like {0,
|
||||
* max_int}. Else, throw an error.
|
||||
*
|
||||
* There is a contrived edge case for ranges that include 1 but are not a fixed {1, 1}.
|
||||
* That case is not supported.
|
||||
*/
|
||||
MIGRAPHX_EXPORT
|
||||
std::vector<shape::dynamic_dimension>
|
||||
compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
|
||||
std::vector<shape::dynamic_dimension> dds1);
|
||||
|
||||
MIGRAPHX_EXPORT
|
||||
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
|
||||
|
||||
MIGRAPHX_EXPORT
|
||||
shape common_shape(const std::vector<shape>& shapes);
|
||||
|
||||
/**
|
||||
* @brief Compute the common (broadcasted) dimensions of a list of fixed shapes
|
||||
*/
|
||||
MIGRAPHX_EXPORT
|
||||
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes);
|
||||
|
||||
/**
|
||||
* @ brief Compute the common (broadcasted) dynamic dimensions of a list of dynamic shapes
|
||||
*/
|
||||
MIGRAPHX_EXPORT
|
||||
std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes);
|
||||
|
||||
/**
|
||||
* @brief Creates and adds instructions to convert input arguments to common shapes and types
|
||||
* by adding multi-broadcast and type convert operations. This is a utility function for creating
|
||||
* operations where the shape and type of inputs need to match. It supports both dynamic and
|
||||
* static-shaped arguments.
|
||||
*
|
||||
* @param m containing module for instruction
|
||||
* @param ins insertion location in instruction list
|
||||
* @param inputs instructions to use as argument list; also, the shapes
|
||||
* attached to each instruction_ref are considered for broadcasting
|
||||
* @return std::vector<instruction_ref> a modified argument list
|
||||
*/
|
||||
MIGRAPHX_EXPORT std::vector<instruction_ref> insert_common_args(module& m,
|
||||
instruction_ref ins,
|
||||
std::vector<instruction_ref> inputs,
|
||||
common_options options = {});
|
||||
|
||||
MIGRAPHX_EXPORT
|
||||
std::vector<instruction_ref>
|
||||
add_common_args(module& m, std::vector<instruction_ref> inputs, common_options options = {});
|
||||
|
||||
MIGRAPHX_EXPORT
|
||||
instruction_ref insert_common_op(module& m,
|
||||
instruction_ref ins,
|
||||
const operation& op,
|
||||
std::vector<instruction_ref> inputs,
|
||||
common_options options = {});
|
||||
|
||||
/**
|
||||
* @brief Wrapper for insert_common_args() which inserts operation at the end of the module.
|
||||
*/
|
||||
MIGRAPHX_EXPORT
|
||||
instruction_ref add_common_op(module& m,
|
||||
const operation& op,
|
||||
std::vector<instruction_ref> inputs,
|
||||
common_options options = {});
|
||||
|
||||
/**
|
||||
* Calculates the broadcasted shape with the given input_shape and broadcasted dimensions.
|
||||
*
|
||||
* @param input_shape static shape to broadcast
|
||||
* @param bcast_lens dimensions to broadcast to
|
||||
* @return broadcasted shape with calculated strides
|
||||
*/
|
||||
MIGRAPHX_EXPORT
|
||||
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
|
||||
64
docker/rocm/migraphx/include/migraphx/common_dims.hpp
Normal file
64
docker/rocm/migraphx/include/migraphx/common_dims.hpp
Normal file
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/// This will compute a higher dimensional space that will preserve the axes
|
||||
/// for both sets of dimensions. Two axes_maps are provided for each of the
|
||||
/// dims that will map the axis to the axes that are used by the result of
|
||||
/// common_dims.
|
||||
struct MIGRAPHX_EXPORT common_dims
|
||||
{
|
||||
static common_dims compute(const std::vector<std::size_t>& dims1,
|
||||
const std::vector<std::size_t>& dims2);
|
||||
|
||||
/// Map the dimensions into the common higher dimensional space. The
|
||||
/// dimension doesnt need to have the same number of elements as the
|
||||
/// common dimension.
|
||||
std::vector<std::size_t> get_dimensions_for(const std::vector<std::size_t>& idims) const;
|
||||
/// Get the corresponding axes map based on the rank of tensor
|
||||
const std::vector<std::vector<std::size_t>>* get_axes_map(std::size_t n) const;
|
||||
std::vector<std::size_t> dims;
|
||||
std::vector<std::vector<std::size_t>> axes_map1;
|
||||
std::vector<std::vector<std::size_t>> axes_map2;
|
||||
};
|
||||
|
||||
template <class Range>
|
||||
auto elements(const Range& r)
|
||||
{
|
||||
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
|
||||
50
docker/rocm/migraphx/include/migraphx/compile_options.hpp
Normal file
50
docker/rocm/migraphx/include/migraphx/compile_options.hpp
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_COMPILE_OPTIONS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_COMPILE_OPTIONS_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/tracer.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct compile_options
|
||||
{
|
||||
/**
|
||||
* Have MIGX allocate memory for parameters and add instructions
|
||||
* to copy parameters and output to/from an offload device like a GPU.
|
||||
*/
|
||||
bool offload_copy = false;
|
||||
|
||||
bool fast_math = true;
|
||||
bool exhaustive_tune = false;
|
||||
|
||||
tracer trace{};
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
71
docker/rocm/migraphx/include/migraphx/compile_src.hpp
Normal file
71
docker/rocm/migraphx/include/migraphx/compile_src.hpp
Normal file
@ -0,0 +1,71 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/filesystem.hpp>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct src_file
|
||||
{
|
||||
fs::path path;
|
||||
std::string_view content;
|
||||
|
||||
src_file() = default;
|
||||
src_file(fs::path file_path, std::string_view file_content)
|
||||
: path{std::move(file_path)}, content{file_content}
|
||||
{
|
||||
}
|
||||
|
||||
explicit src_file(const std::pair<std::string_view, std::string_view>& pair)
|
||||
: path{pair.first}, content{pair.second}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct MIGRAPHX_EXPORT src_compiler
|
||||
{
|
||||
#ifdef _WIN32
|
||||
fs::path compiler = MIGRAPHX_CXX_COMPILER;
|
||||
#else
|
||||
fs::path compiler = "c++";
|
||||
#endif
|
||||
std::vector<std::string> flags = {};
|
||||
fs::path output = {};
|
||||
fs::path launcher = {};
|
||||
std::string out_ext = ".o";
|
||||
std::function<fs::path(fs::path)> process = nullptr;
|
||||
std::vector<char> compile(const std::vector<src_file>& srcs) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP
|
||||
303
docker/rocm/migraphx/include/migraphx/concat_opt.hpp
Normal file
303
docker/rocm/migraphx/include/migraphx/concat_opt.hpp
Normal file
@ -0,0 +1,303 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_CONCAT_OPT_HPP
|
||||
#define MIGRAPHX_GUARD_CONCAT_OPT_HPP
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/op/concat.hpp>
|
||||
#include <migraphx/optional.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifdef DOXYGEN
|
||||
|
||||
/// An interface for target-dependent optimization for the concat instruction
|
||||
struct concat_optimization
|
||||
{
|
||||
/// A name of the target-dependent allocate operator
|
||||
std::string allocate() const;
|
||||
/// Return the target-independent concat operator
|
||||
optional<op::concat> get_concat(const operation& op) const;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
#ifdef TYPE_ERASED_DECLARATION
|
||||
|
||||
// Type-erased interface for:
|
||||
struct MIGRAPHX_EXPORT concat_optimization
|
||||
{
|
||||
//
|
||||
std::string allocate() const;
|
||||
//
|
||||
optional<op::concat> get_concat(const operation& op) const;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
struct concat_optimization
|
||||
{
|
||||
private:
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
struct private_te_unwrap_reference
|
||||
{
|
||||
using type = PrivateDetailTypeErasedT;
|
||||
};
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
struct private_te_unwrap_reference<std::reference_wrapper<PrivateDetailTypeErasedT>>
|
||||
{
|
||||
using type = PrivateDetailTypeErasedT;
|
||||
};
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_pure = typename std::remove_cv<
|
||||
typename std::remove_reference<PrivateDetailTypeErasedT>::type>::type;
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_constraints_impl =
|
||||
decltype(std::declval<PrivateDetailTypeErasedT>().allocate(),
|
||||
std::declval<PrivateDetailTypeErasedT>().get_concat(
|
||||
std::declval<const operation&>()),
|
||||
void());
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_constraints = private_te_constraints_impl<
|
||||
typename private_te_unwrap_reference<private_te_pure<PrivateDetailTypeErasedT>>::type>;
|
||||
|
||||
public:
|
||||
// Constructors
|
||||
concat_optimization() = default;
|
||||
|
||||
template <typename PrivateDetailTypeErasedT,
|
||||
typename = private_te_constraints<PrivateDetailTypeErasedT>,
|
||||
typename = typename std::enable_if<
|
||||
not std::is_same<private_te_pure<PrivateDetailTypeErasedT>,
|
||||
concat_optimization>{}>::type>
|
||||
concat_optimization(PrivateDetailTypeErasedT&& value)
|
||||
: private_detail_te_handle_mem_var(
|
||||
std::make_shared<
|
||||
private_detail_te_handle_type<private_te_pure<PrivateDetailTypeErasedT>>>(
|
||||
std::forward<PrivateDetailTypeErasedT>(value)))
|
||||
{
|
||||
}
|
||||
|
||||
// Assignment
|
||||
template <typename PrivateDetailTypeErasedT,
|
||||
typename = private_te_constraints<PrivateDetailTypeErasedT>,
|
||||
typename = typename std::enable_if<
|
||||
not std::is_same<private_te_pure<PrivateDetailTypeErasedT>,
|
||||
concat_optimization>{}>::type>
|
||||
concat_optimization& operator=(PrivateDetailTypeErasedT&& value)
|
||||
{
|
||||
using std::swap;
|
||||
auto* derived = this->any_cast<private_te_pure<PrivateDetailTypeErasedT>>();
|
||||
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
|
||||
{
|
||||
*derived = std::forward<PrivateDetailTypeErasedT>(value);
|
||||
}
|
||||
else
|
||||
{
|
||||
concat_optimization rhs(value);
|
||||
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Cast
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
PrivateDetailTypeErasedT* any_cast()
|
||||
{
|
||||
return this->type_id() == typeid(PrivateDetailTypeErasedT)
|
||||
? std::addressof(static_cast<private_detail_te_handle_type<
|
||||
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
|
||||
private_detail_te_get_handle())
|
||||
.private_detail_te_value)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
|
||||
{
|
||||
return this->type_id() == typeid(PrivateDetailTypeErasedT)
|
||||
? std::addressof(static_cast<const private_detail_te_handle_type<
|
||||
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
|
||||
private_detail_te_get_handle())
|
||||
.private_detail_te_value)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
const std::type_info& type_id() const
|
||||
{
|
||||
if(private_detail_te_handle_empty())
|
||||
return typeid(std::nullptr_t);
|
||||
else
|
||||
return private_detail_te_get_handle().type();
|
||||
}
|
||||
|
||||
std::string allocate() const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().allocate();
|
||||
}
|
||||
|
||||
optional<op::concat> get_concat(const operation& op) const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().get_concat(op);
|
||||
}
|
||||
|
||||
friend bool is_shared(const concat_optimization& private_detail_x,
|
||||
const concat_optimization& private_detail_y)
|
||||
{
|
||||
return private_detail_x.private_detail_te_handle_mem_var ==
|
||||
private_detail_y.private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
private:
|
||||
struct private_detail_te_handle_base_type
|
||||
{
|
||||
virtual ~private_detail_te_handle_base_type() {}
|
||||
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
|
||||
virtual const std::type_info& type() const = 0;
|
||||
|
||||
virtual std::string allocate() const = 0;
|
||||
virtual optional<op::concat> get_concat(const operation& op) const = 0;
|
||||
};
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
struct private_detail_te_handle_type : private_detail_te_handle_base_type
|
||||
{
|
||||
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
|
||||
private_detail_te_handle_type(
|
||||
PrivateDetailTypeErasedT value,
|
||||
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
|
||||
nullptr)
|
||||
: private_detail_te_value(value)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
|
||||
private_detail_te_handle_type(
|
||||
PrivateDetailTypeErasedT value,
|
||||
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
|
||||
int>::type* = nullptr) noexcept
|
||||
: private_detail_te_value(std::move(value))
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
|
||||
{
|
||||
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
|
||||
}
|
||||
|
||||
const std::type_info& type() const override { return typeid(private_detail_te_value); }
|
||||
|
||||
std::string allocate() const override { return private_detail_te_value.allocate(); }
|
||||
|
||||
optional<op::concat> get_concat(const operation& op) const override
|
||||
{
|
||||
|
||||
return private_detail_te_value.get_concat(op);
|
||||
}
|
||||
|
||||
PrivateDetailTypeErasedT private_detail_te_value;
|
||||
};
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
|
||||
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
|
||||
{
|
||||
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
|
||||
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
bool private_detail_te_handle_empty() const
|
||||
{
|
||||
return private_detail_te_handle_mem_var == nullptr;
|
||||
}
|
||||
|
||||
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
|
||||
{
|
||||
assert(private_detail_te_handle_mem_var != nullptr);
|
||||
return *private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
private_detail_te_handle_base_type& private_detail_te_get_handle()
|
||||
{
|
||||
assert(private_detail_te_handle_mem_var != nullptr);
|
||||
if(private_detail_te_handle_mem_var.use_count() > 1)
|
||||
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
|
||||
return *private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
|
||||
};
|
||||
|
||||
template <typename ValueType>
|
||||
inline const ValueType* any_cast(const concat_optimization* x)
|
||||
{
|
||||
return x->any_cast<ValueType>();
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline ValueType* any_cast(concat_optimization* x)
|
||||
{
|
||||
return x->any_cast<ValueType>();
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline ValueType& any_cast(concat_optimization& x)
|
||||
{
|
||||
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
|
||||
if(y == nullptr)
|
||||
throw std::bad_cast();
|
||||
return *y;
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline const ValueType& any_cast(const concat_optimization& x)
|
||||
{
|
||||
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
|
||||
if(y == nullptr)
|
||||
throw std::bad_cast();
|
||||
return *y;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
58
docker/rocm/migraphx/include/migraphx/config.hpp
Normal file
58
docker/rocm/migraphx/include/migraphx/config.hpp
Normal file
@ -0,0 +1,58 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_CONFIG_HPP
|
||||
#define MIGRAPHX_GUARD_CONFIG_HPP
|
||||
|
||||
#include <migraphx/export.h>
|
||||
#include <ciso646>
|
||||
|
||||
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
|
||||
|
||||
#ifdef BUILD_DEV
|
||||
#define MIGRAPHX_INLINE_NS version_1
|
||||
#else
|
||||
#include <migraphx/version.h>
|
||||
|
||||
#define MIGRAPHX_VERSION_PRIMITIVE_CONCAT(x, y) x##_##y
|
||||
#define MIGRAPHX_VERSION_CONCAT(x, y) MIGRAPHX_VERSION_PRIMITIVE_CONCAT(x, y)
|
||||
|
||||
#define MIGRAPHX_VERSION \
|
||||
MIGRAPHX_VERSION_CONCAT( \
|
||||
MIGRAPHX_VERSION_CONCAT(MIGRAPHX_VERSION_MAJOR, MIGRAPHX_VERSION_MINOR), \
|
||||
MIGRAPHX_VERSION_PATCH)
|
||||
|
||||
#define MIGRAPHX_INLINE_NS MIGRAPHX_VERSION_CONCAT(version, MIGRAPHX_VERSION)
|
||||
#endif // build_dev
|
||||
#endif // clang_tidy
|
||||
|
||||
#ifdef DOXYGEN
|
||||
#define MIGRAPHX_INLINE_NS internal
|
||||
#endif // doxygen
|
||||
|
||||
#ifdef MIGRAPHX_USE_CLANG_TIDY
|
||||
#define MIGRAPHX_TIDY_CONST const
|
||||
#else
|
||||
#define MIGRAPHX_TIDY_CONST
|
||||
#endif // tidy_const
|
||||
#endif // clang_tidy
|
||||
466
docker/rocm/migraphx/include/migraphx/context.hpp
Normal file
466
docker/rocm/migraphx/include/migraphx/context.hpp
Normal file
@ -0,0 +1,466 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_CONTEXT_HPP
|
||||
#define MIGRAPHX_GUARD_CONTEXT_HPP
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/value.hpp>
|
||||
#include <migraphx/any_ptr.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifdef DOXYGEN
|
||||
|
||||
/// A context is used to store internal data for a `target`. A context is
|
||||
/// constructed by a target during compilation and passed to the operations
|
||||
/// during `eval`.
|
||||
struct context
|
||||
{
|
||||
/// Wait for any tasks in the context to complete
|
||||
void finish() const;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template <class T>
|
||||
value to_value_context(const T&)
|
||||
{
|
||||
return value{};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void from_value_context(T&, const value&)
|
||||
{
|
||||
}
|
||||
|
||||
template <class T>
|
||||
any_ptr get_queue_context(T&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void wait_for_context(T&, any_ptr)
|
||||
{
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void finish_on_context(T&, any_ptr)
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef TYPE_ERASED_DECLARATION
|
||||
|
||||
// Type-erased interface for:
|
||||
struct MIGRAPHX_EXPORT context
|
||||
{
|
||||
// (optional)
|
||||
value to_value() const;
|
||||
// (optional)
|
||||
void from_value(const value& v);
|
||||
// (optional)
|
||||
any_ptr get_queue();
|
||||
// (optional)
|
||||
void wait_for(any_ptr queue);
|
||||
// (optional)
|
||||
void finish_on(any_ptr queue);
|
||||
//
|
||||
void finish() const;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
struct context
|
||||
{
|
||||
private:
|
||||
template <class T>
|
||||
static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
|
||||
-> decltype(private_detail_te_self.to_value())
|
||||
{
|
||||
return private_detail_te_self.to_value();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
|
||||
{
|
||||
return to_value_context(private_detail_te_self);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static auto
|
||||
private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
|
||||
-> decltype(private_detail_te_self.from_value(v))
|
||||
{
|
||||
private_detail_te_self.from_value(v);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static void
|
||||
private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
|
||||
{
|
||||
from_value_context(private_detail_te_self, v);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
|
||||
-> decltype(private_detail_te_self.get_queue())
|
||||
{
|
||||
return private_detail_te_self.get_queue();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
|
||||
{
|
||||
return get_queue_context(private_detail_te_self);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
|
||||
-> decltype(private_detail_te_self.wait_for(queue))
|
||||
{
|
||||
private_detail_te_self.wait_for(queue);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
|
||||
{
|
||||
wait_for_context(private_detail_te_self, queue);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
|
||||
-> decltype(private_detail_te_self.finish_on(queue))
|
||||
{
|
||||
private_detail_te_self.finish_on(queue);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static void
|
||||
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
|
||||
{
|
||||
finish_on_context(private_detail_te_self, queue);
|
||||
}
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
struct private_te_unwrap_reference
|
||||
{
|
||||
using type = PrivateDetailTypeErasedT;
|
||||
};
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
struct private_te_unwrap_reference<std::reference_wrapper<PrivateDetailTypeErasedT>>
|
||||
{
|
||||
using type = PrivateDetailTypeErasedT;
|
||||
};
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_pure = typename std::remove_cv<
|
||||
typename std::remove_reference<PrivateDetailTypeErasedT>::type>::type;
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_constraints_impl =
|
||||
decltype(private_detail_te_default_to_value(char(0),
|
||||
std::declval<PrivateDetailTypeErasedT>()),
|
||||
private_detail_te_default_from_value(char(0),
|
||||
std::declval<PrivateDetailTypeErasedT>(),
|
||||
std::declval<const value&>()),
|
||||
private_detail_te_default_get_queue(char(0),
|
||||
std::declval<PrivateDetailTypeErasedT>()),
|
||||
private_detail_te_default_wait_for(
|
||||
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
|
||||
private_detail_te_default_finish_on(
|
||||
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
|
||||
std::declval<PrivateDetailTypeErasedT>().finish(),
|
||||
void());
|
||||
|
||||
template <class PrivateDetailTypeErasedT>
|
||||
using private_te_constraints = private_te_constraints_impl<
|
||||
typename private_te_unwrap_reference<private_te_pure<PrivateDetailTypeErasedT>>::type>;
|
||||
|
||||
public:
|
||||
// Constructors
|
||||
context() = default;
|
||||
|
||||
template <typename PrivateDetailTypeErasedT,
|
||||
typename = private_te_constraints<PrivateDetailTypeErasedT>,
|
||||
typename = typename std::enable_if<
|
||||
not std::is_same<private_te_pure<PrivateDetailTypeErasedT>, context>{}>::type>
|
||||
context(PrivateDetailTypeErasedT&& value)
|
||||
: private_detail_te_handle_mem_var(
|
||||
std::make_shared<
|
||||
private_detail_te_handle_type<private_te_pure<PrivateDetailTypeErasedT>>>(
|
||||
std::forward<PrivateDetailTypeErasedT>(value)))
|
||||
{
|
||||
}
|
||||
|
||||
// Assignment
|
||||
template <typename PrivateDetailTypeErasedT,
|
||||
typename = private_te_constraints<PrivateDetailTypeErasedT>,
|
||||
typename = typename std::enable_if<
|
||||
not std::is_same<private_te_pure<PrivateDetailTypeErasedT>, context>{}>::type>
|
||||
context& operator=(PrivateDetailTypeErasedT&& value)
|
||||
{
|
||||
using std::swap;
|
||||
auto* derived = this->any_cast<private_te_pure<PrivateDetailTypeErasedT>>();
|
||||
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
|
||||
{
|
||||
*derived = std::forward<PrivateDetailTypeErasedT>(value);
|
||||
}
|
||||
else
|
||||
{
|
||||
context rhs(value);
|
||||
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Cast
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
PrivateDetailTypeErasedT* any_cast()
|
||||
{
|
||||
return this->type_id() == typeid(PrivateDetailTypeErasedT)
|
||||
? std::addressof(static_cast<private_detail_te_handle_type<
|
||||
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
|
||||
private_detail_te_get_handle())
|
||||
.private_detail_te_value)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
|
||||
{
|
||||
return this->type_id() == typeid(PrivateDetailTypeErasedT)
|
||||
? std::addressof(static_cast<const private_detail_te_handle_type<
|
||||
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
|
||||
private_detail_te_get_handle())
|
||||
.private_detail_te_value)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
const std::type_info& type_id() const
|
||||
{
|
||||
if(private_detail_te_handle_empty())
|
||||
return typeid(std::nullptr_t);
|
||||
else
|
||||
return private_detail_te_get_handle().type();
|
||||
}
|
||||
|
||||
value to_value() const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().to_value();
|
||||
}
|
||||
|
||||
void from_value(const value& v)
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
(*this).private_detail_te_get_handle().from_value(v);
|
||||
}
|
||||
|
||||
any_ptr get_queue()
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
return (*this).private_detail_te_get_handle().get_queue();
|
||||
}
|
||||
|
||||
void wait_for(any_ptr queue)
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
(*this).private_detail_te_get_handle().wait_for(queue);
|
||||
}
|
||||
|
||||
void finish_on(any_ptr queue)
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
(*this).private_detail_te_get_handle().finish_on(queue);
|
||||
}
|
||||
|
||||
void finish() const
|
||||
{
|
||||
assert((*this).private_detail_te_handle_mem_var);
|
||||
(*this).private_detail_te_get_handle().finish();
|
||||
}
|
||||
|
||||
friend bool is_shared(const context& private_detail_x, const context& private_detail_y)
|
||||
{
|
||||
return private_detail_x.private_detail_te_handle_mem_var ==
|
||||
private_detail_y.private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
private:
|
||||
struct private_detail_te_handle_base_type
|
||||
{
|
||||
virtual ~private_detail_te_handle_base_type() {}
|
||||
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
|
||||
virtual const std::type_info& type() const = 0;
|
||||
|
||||
virtual value to_value() const = 0;
|
||||
virtual void from_value(const value& v) = 0;
|
||||
virtual any_ptr get_queue() = 0;
|
||||
virtual void wait_for(any_ptr queue) = 0;
|
||||
virtual void finish_on(any_ptr queue) = 0;
|
||||
virtual void finish() const = 0;
|
||||
};
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
struct private_detail_te_handle_type : private_detail_te_handle_base_type
|
||||
{
|
||||
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
|
||||
private_detail_te_handle_type(
|
||||
PrivateDetailTypeErasedT value,
|
||||
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
|
||||
nullptr)
|
||||
: private_detail_te_value(value)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
|
||||
private_detail_te_handle_type(
|
||||
PrivateDetailTypeErasedT value,
|
||||
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
|
||||
int>::type* = nullptr) noexcept
|
||||
: private_detail_te_value(std::move(value))
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
|
||||
{
|
||||
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
|
||||
}
|
||||
|
||||
const std::type_info& type() const override { return typeid(private_detail_te_value); }
|
||||
|
||||
value to_value() const override
|
||||
{
|
||||
|
||||
return private_detail_te_default_to_value(char(0), private_detail_te_value);
|
||||
}
|
||||
|
||||
void from_value(const value& v) override
|
||||
{
|
||||
|
||||
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
|
||||
}
|
||||
|
||||
any_ptr get_queue() override
|
||||
{
|
||||
|
||||
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
|
||||
}
|
||||
|
||||
void wait_for(any_ptr queue) override
|
||||
{
|
||||
|
||||
private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
|
||||
}
|
||||
|
||||
void finish_on(any_ptr queue) override
|
||||
{
|
||||
|
||||
private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
|
||||
}
|
||||
|
||||
void finish() const override { private_detail_te_value.finish(); }
|
||||
|
||||
PrivateDetailTypeErasedT private_detail_te_value;
|
||||
};
|
||||
|
||||
template <typename PrivateDetailTypeErasedT>
|
||||
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
|
||||
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
|
||||
{
|
||||
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
|
||||
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
bool private_detail_te_handle_empty() const
|
||||
{
|
||||
return private_detail_te_handle_mem_var == nullptr;
|
||||
}
|
||||
|
||||
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
|
||||
{
|
||||
assert(private_detail_te_handle_mem_var != nullptr);
|
||||
return *private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
private_detail_te_handle_base_type& private_detail_te_get_handle()
|
||||
{
|
||||
assert(private_detail_te_handle_mem_var != nullptr);
|
||||
if(private_detail_te_handle_mem_var.use_count() > 1)
|
||||
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
|
||||
return *private_detail_te_handle_mem_var;
|
||||
}
|
||||
|
||||
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
|
||||
};
|
||||
|
||||
template <typename ValueType>
|
||||
inline const ValueType* any_cast(const context* x)
|
||||
{
|
||||
return x->any_cast<ValueType>();
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline ValueType* any_cast(context* x)
|
||||
{
|
||||
return x->any_cast<ValueType>();
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline ValueType& any_cast(context& x)
|
||||
{
|
||||
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
|
||||
if(y == nullptr)
|
||||
throw std::bad_cast();
|
||||
return *y;
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
inline const ValueType& any_cast(const context& x)
|
||||
{
|
||||
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
|
||||
if(y == nullptr)
|
||||
throw std::bad_cast();
|
||||
return *y;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
|
||||
|
||||
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
38
docker/rocm/migraphx/include/migraphx/convert_to_json.hpp
Normal file
38
docker/rocm/migraphx/include/migraphx/convert_to_json.hpp
Normal file
@ -0,0 +1,38 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_CONVERT_TO_JSON_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_CONVERT_TO_JSON_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_EXPORT std::string convert_to_json(const std::string& str);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
103
docker/rocm/migraphx/include/migraphx/convolution.hpp
Normal file
103
docker/rocm/migraphx/include/migraphx/convolution.hpp
Normal file
@ -0,0 +1,103 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/dfor.hpp>
|
||||
#include <migraphx/par_for.hpp>
|
||||
#include <migraphx/shape_for_each.hpp>
|
||||
#include <migraphx/tensor_view.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class Output, class T, class Padding, class Stride, class Dilation>
|
||||
void convolution(
|
||||
Output output, T input, T weights, Padding padding, Stride stride, Dilation dilation, int group)
|
||||
{
|
||||
auto output_shape = output.get_shape();
|
||||
auto in_lens = input.get_shape().lens();
|
||||
|
||||
auto wei_lens = weights.get_shape().lens();
|
||||
auto wei_n = wei_lens[0];
|
||||
auto wei_c = wei_lens[1];
|
||||
std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
|
||||
|
||||
par_for(output_shape.elements(), [&](auto i) {
|
||||
auto idx_o = output_shape.multi(i);
|
||||
auto w = idx_o[1];
|
||||
auto n_dim = idx_o.size();
|
||||
|
||||
std::vector<std::ptrdiff_t> win_start;
|
||||
for(std::size_t dim = 2; dim < n_dim; ++dim)
|
||||
{
|
||||
auto d_2 = dim - 2;
|
||||
win_start.push_back(std::ptrdiff_t(idx_o[dim] * stride[d_2]) -
|
||||
std::ptrdiff_t(padding[d_2]));
|
||||
}
|
||||
const auto group_id = w / (wei_n / group);
|
||||
|
||||
shape win_shape{output_shape.type(), win_size};
|
||||
|
||||
double acc = 0.0;
|
||||
shape_for_each(win_shape, [&](const auto& idx_win) {
|
||||
auto k = idx_win[0];
|
||||
const auto in_ch = group_id * wei_c + k;
|
||||
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
|
||||
idx[1] = in_ch;
|
||||
std::vector<std::ptrdiff_t> idx_dil(idx_win.size() - 1);
|
||||
std::transform(idx_win.cbegin() + 1,
|
||||
idx_win.cend(),
|
||||
dilation.cbegin(),
|
||||
idx_dil.begin(),
|
||||
[](std::ptrdiff_t ii, std::ptrdiff_t d) { return d * ii; });
|
||||
std::transform(idx_dil.begin(),
|
||||
idx_dil.end(),
|
||||
win_start.begin(),
|
||||
idx.begin() + 2,
|
||||
[](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
|
||||
std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
|
||||
idx_wei[0] = w;
|
||||
std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
|
||||
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
|
||||
std::equal(idx.begin(),
|
||||
idx.end(),
|
||||
in_lens.begin(),
|
||||
in_lens.end(),
|
||||
std::less<std::ptrdiff_t>{}))
|
||||
{
|
||||
acc += input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
|
||||
}
|
||||
});
|
||||
|
||||
output[i] = acc;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COPY_ASSIGNABLE_FUNCTION_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_COPY_ASSIGNABLE_FUNCTION_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/optional.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class F>
|
||||
struct copy_assignable_function_wrapper
|
||||
{
|
||||
optional<F> f;
|
||||
|
||||
copy_assignable_function_wrapper(F pf) : f(std::move(pf)) {}
|
||||
copy_assignable_function_wrapper(const copy_assignable_function_wrapper& other) = default;
|
||||
copy_assignable_function_wrapper(copy_assignable_function_wrapper&& other) noexcept = default;
|
||||
copy_assignable_function_wrapper& operator=(copy_assignable_function_wrapper other)
|
||||
{
|
||||
f.reset();
|
||||
if(other.f.has_value())
|
||||
f.emplace(std::move(*other.f));
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
auto operator()(Ts&&... xs) const -> decltype((*f)(std::forward<Ts>(xs)...))
|
||||
{
|
||||
return (*f)(std::forward<Ts>(xs)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class F>
|
||||
using copy_assignable_function =
|
||||
std::conditional_t<std::is_copy_assignable<F>{}, F, copy_assignable_function_wrapper<F>>;
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_COPY_ASSIGNABLE_FUNCTION_HPP
|
||||
122
docker/rocm/migraphx/include/migraphx/cpp_generator.hpp
Normal file
122
docker/rocm/migraphx/include/migraphx/cpp_generator.hpp
Normal file
@ -0,0 +1,122 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct operation;
|
||||
struct module;
|
||||
struct shape;
|
||||
|
||||
struct cpp_generator_impl;
|
||||
|
||||
struct MIGRAPHX_EXPORT cpp_generator
|
||||
{
|
||||
using generate_module_callback = std::function<std::string(
|
||||
instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>;
|
||||
struct param
|
||||
{
|
||||
std::string name;
|
||||
std::string type;
|
||||
};
|
||||
|
||||
struct MIGRAPHX_EXPORT function
|
||||
{
|
||||
std::vector<param> params = {};
|
||||
std::string body = "";
|
||||
std::string return_type = "void";
|
||||
std::string name = "";
|
||||
std::vector<std::string> attributes = {};
|
||||
std::vector<std::string> tparams = {};
|
||||
function& set_body(const module& m, const generate_module_callback& g);
|
||||
function& set_body(const std::string& s)
|
||||
{
|
||||
body = s;
|
||||
return *this;
|
||||
}
|
||||
function& set_name(const std::string& s)
|
||||
{
|
||||
name = s;
|
||||
return *this;
|
||||
}
|
||||
function& set_attributes(std::vector<std::string> attrs)
|
||||
{
|
||||
attributes = std::move(attrs);
|
||||
return *this;
|
||||
}
|
||||
function& set_types(const module& m);
|
||||
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
|
||||
function& set_generic_types(const module& m);
|
||||
function& add_generic_param(const std::string& pname);
|
||||
function& unused_param(const std::string& pname);
|
||||
};
|
||||
|
||||
cpp_generator();
|
||||
|
||||
// move constructor
|
||||
cpp_generator(cpp_generator&&) noexcept;
|
||||
|
||||
// copy assignment operator
|
||||
cpp_generator& operator=(cpp_generator rhs);
|
||||
|
||||
~cpp_generator() noexcept;
|
||||
|
||||
void fmap(const std::function<std::string(std::string)>& f);
|
||||
|
||||
void fresult(const std::function<std::string(shape)>& f);
|
||||
|
||||
void always_return_tuple(bool b = true);
|
||||
|
||||
void add_point_op(const std::string& op_name, const std::string& code);
|
||||
|
||||
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
|
||||
|
||||
std::string str() const;
|
||||
|
||||
function generate_module(const module& m, const generate_module_callback& g);
|
||||
|
||||
function generate_module(const module& m);
|
||||
|
||||
std::string create_function(const function& f);
|
||||
|
||||
static std::vector<std::string>
|
||||
to_args(const std::vector<instruction_ref>& inputs,
|
||||
const std::unordered_map<instruction_ref, std::string>& names);
|
||||
|
||||
private:
|
||||
std::unique_ptr<cpp_generator_impl> impl;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP
|
||||
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
struct program;
|
||||
|
||||
/**
|
||||
* Remove instructions where the output is not used.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT dead_code_elimination
|
||||
{
|
||||
std::string name() const { return "dead_code_elimination"; }
|
||||
void apply(module& m) const;
|
||||
void apply(program& p) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
52
docker/rocm/migraphx/include/migraphx/dfor.hpp
Normal file
52
docker/rocm/migraphx/include/migraphx/dfor.hpp
Normal file
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DFOR_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_DFOR_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
// Multidimensional for loop
|
||||
inline auto dfor()
|
||||
{
|
||||
return [](auto f) { f(); };
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
auto dfor(T x, Ts... xs)
|
||||
{
|
||||
return [=](auto f) {
|
||||
for(T i = 0; i < x; i++)
|
||||
{
|
||||
dfor(xs...)([&](Ts... is) { f(i, is...); });
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
50
docker/rocm/migraphx/include/migraphx/dom_info.hpp
Normal file
50
docker/rocm/migraphx/include/migraphx/dom_info.hpp
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
struct MIGRAPHX_EXPORT dominator_info
|
||||
{
|
||||
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2) const;
|
||||
|
||||
std::unordered_map<instruction_ref, instruction_ref> ins2idom;
|
||||
};
|
||||
|
||||
MIGRAPHX_EXPORT dominator_info compute_dominator(const module& m);
|
||||
// MIGRAPHX_EXPORT dominator_info compute_dominator_naive(const module& m);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
75
docker/rocm/migraphx/include/migraphx/dyn_output.hpp
Normal file
75
docker/rocm/migraphx/include/migraphx/dyn_output.hpp
Normal file
@ -0,0 +1,75 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
|
||||
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/argument.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct dyn_output
|
||||
{
|
||||
// original shape from the instruction
|
||||
shape ins_shape;
|
||||
// shape computed at eval time using input arguments
|
||||
shape computed_shape;
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle dynamic and static shape at evaluation time.
|
||||
* If converted to shape type, returns original ins_shape.
|
||||
* If converted to dyn_output type, will compute an output shape using the input arguments.
|
||||
*/
|
||||
template <class F>
|
||||
struct compute_output_shape
|
||||
{
|
||||
F ins_inputs;
|
||||
|
||||
operator dyn_output() const
|
||||
{
|
||||
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
|
||||
if(ins_shape.dynamic())
|
||||
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
|
||||
return dyn_output{ins_shape, ins_shape};
|
||||
});
|
||||
}
|
||||
|
||||
operator shape() const
|
||||
{
|
||||
return ins_inputs(
|
||||
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
|
||||
}
|
||||
};
|
||||
|
||||
template <class F>
|
||||
compute_output_shape<F> make_compute_output_shape(F f)
|
||||
{
|
||||
return {f};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif
|
||||
75
docker/rocm/migraphx/include/migraphx/dynamic_loader.hpp
Normal file
75
docker/rocm/migraphx/include/migraphx/dynamic_loader.hpp
Normal file
@ -0,0 +1,75 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/filesystem.hpp>
|
||||
#include <migraphx/optional.hpp>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct dynamic_loader_impl;
|
||||
|
||||
struct MIGRAPHX_EXPORT dynamic_loader
|
||||
{
|
||||
template <class T>
|
||||
static fs::path path(T* address)
|
||||
{
|
||||
return path(reinterpret_cast<void*>(address));
|
||||
}
|
||||
static fs::path path(void* address);
|
||||
static optional<dynamic_loader> try_load(const fs::path& p);
|
||||
|
||||
dynamic_loader() = default;
|
||||
|
||||
dynamic_loader(const fs::path& p);
|
||||
|
||||
dynamic_loader(const char* image, std::size_t size);
|
||||
|
||||
dynamic_loader(const std::vector<char>& buffer);
|
||||
|
||||
std::shared_ptr<void> get_symbol(const std::string& name) const;
|
||||
|
||||
template <class F>
|
||||
std::function<F> get_function(const std::string& name) const
|
||||
{
|
||||
auto s = get_symbol(name);
|
||||
return [=](auto&&... xs) -> decltype(auto) {
|
||||
auto f = reinterpret_cast<std::add_pointer_t<F>>(s.get());
|
||||
return f(std::forward<decltype(xs)>(xs)...);
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<dynamic_loader_impl> impl;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP
|
||||
@ -0,0 +1,51 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove memory allocations. This will create a parameter which is the max of all memory used in
|
||||
* the program.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_allocation
|
||||
{
|
||||
std::string allocation_op{};
|
||||
std::size_t alignment = 32;
|
||||
std::string name() const { return "eliminate_allocation"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove identical instructions.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_common_subexpression
|
||||
{
|
||||
std::string name() const { return "eliminate_common_subexpression"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
50
docker/rocm/migraphx/include/migraphx/eliminate_concat.hpp
Normal file
50
docker/rocm/migraphx/include/migraphx/eliminate_concat.hpp
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/concat_opt.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove concat operators by having each operator can write to different chunk of memory.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_concat
|
||||
{
|
||||
concat_optimization concat_opt;
|
||||
std::string name() const { return "eliminate_concat"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove contiguous instructions by checking if the operator can use non-standard shapes.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_contiguous
|
||||
{
|
||||
std::string op_name;
|
||||
std::string name() const { return "eliminate_contiguous"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
48
docker/rocm/migraphx/include/migraphx/eliminate_convert.hpp
Normal file
48
docker/rocm/migraphx/include/migraphx/eliminate_convert.hpp
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONVERTS_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONVERTS_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove nested converts and nop converts.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_convert
|
||||
{
|
||||
std::string name() const { return "eliminate_convert"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_DATA_TYPE_HPP
|
||||
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_DATA_TYPE_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove data types. This will instert convert operators so the data type
|
||||
* is not used by any operator.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_data_type
|
||||
{
|
||||
std::set<shape::type_t> unsupported_types;
|
||||
shape::type_t target_type;
|
||||
std::set<std::string> unsupported_ops = {"all"};
|
||||
std::string name() const { return "eliminate_data_type"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
50
docker/rocm/migraphx/include/migraphx/eliminate_identity.hpp
Normal file
50
docker/rocm/migraphx/include/migraphx/eliminate_identity.hpp
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove identity instructions. Currently when used as the last pass, it will
|
||||
* preserve the semantics of previous program state, therefore dead code elimination
|
||||
* should not be used afterwards.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_identity
|
||||
{
|
||||
std::string name() const { return "eliminate_identity"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
52
docker/rocm/migraphx/include/migraphx/eliminate_pad.hpp
Normal file
52
docker/rocm/migraphx/include/migraphx/eliminate_pad.hpp
Normal file
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Remove pads if they can be written as an
|
||||
* attribute to another op (im2col, convolution, pooling)
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT eliminate_pad
|
||||
{
|
||||
std::string name() const { return "eliminate_pad"; }
|
||||
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
80
docker/rocm/migraphx/include/migraphx/env.hpp
Normal file
80
docker/rocm/migraphx/include/migraphx/env.hpp
Normal file
@ -0,0 +1,80 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_ENV_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_ENV_HPP
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
// Declare a cached environment variable
|
||||
#define MIGRAPHX_DECLARE_ENV_VAR(x) \
|
||||
struct x \
|
||||
{ \
|
||||
static const char* value() { return #x; } \
|
||||
}; // NOLINT
|
||||
|
||||
MIGRAPHX_EXPORT bool enabled(const char* name);
|
||||
MIGRAPHX_EXPORT bool disabled(const char* name);
|
||||
MIGRAPHX_EXPORT std::vector<std::string> env(const char* name);
|
||||
|
||||
MIGRAPHX_EXPORT std::size_t value_of(const char* name, std::size_t fallback = 0);
|
||||
|
||||
MIGRAPHX_EXPORT std::string string_value_of(const char* name, std::string fallback = "");
|
||||
|
||||
template <class T>
|
||||
bool enabled(T)
|
||||
{
|
||||
static const bool result = enabled(T::value());
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool disabled(T)
|
||||
{
|
||||
static const bool result = disabled(T::value());
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::size_t value_of(T, std::size_t fallback = 0)
|
||||
{
|
||||
static const std::size_t result = value_of(T::value(), fallback);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::string string_value_of(T, std::string fallback = "")
|
||||
{
|
||||
static const std::string result = string_value_of(T::value(), fallback);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
69
docker/rocm/migraphx/include/migraphx/erase.hpp
Normal file
69
docker/rocm/migraphx/include/migraphx/erase.hpp
Normal file
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_ERASE_HPP
|
||||
#define MIGRAPHX_GUARD_ERASE_HPP
|
||||
|
||||
#include <algorithm>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* @brief Erase all elements from a container
|
||||
*
|
||||
* @param r The container to erase elements from
|
||||
* @param value The value to be erased
|
||||
* @return Returns iterator to erased element
|
||||
*/
|
||||
template <class R, class T>
|
||||
auto erase(R&& r, const T& value)
|
||||
{
|
||||
return r.erase(std::remove(r.begin(), r.end(), value), r.end());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Erase all elements from a container
|
||||
*
|
||||
* @param r The container to erase elements from
|
||||
* @param pred Predicate function that selects which elements should be erased.
|
||||
*/
|
||||
template <class R, class P>
|
||||
void erase_if(R&& r, P&& pred)
|
||||
{
|
||||
auto first = r.begin();
|
||||
auto last = r.end();
|
||||
while(first != last)
|
||||
{
|
||||
if(pred(*first))
|
||||
first = r.erase(first);
|
||||
else
|
||||
first++;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
86
docker/rocm/migraphx/include/migraphx/errors.hpp
Normal file
86
docker/rocm/migraphx/include/migraphx/errors.hpp
Normal file
@ -0,0 +1,86 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_ERRORS_HPP
|
||||
#define MIGRAPHX_GUARD_ERRORS_HPP
|
||||
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/// Represents exceptions that can be thrown by migraphxlib
|
||||
struct exception : std::runtime_error
|
||||
{
|
||||
unsigned int error;
|
||||
exception(unsigned int e = 0, const std::string& msg = "") : std::runtime_error(msg), error(e)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Create an exception object
|
||||
*
|
||||
* @param context A message that says where the exception occurred
|
||||
* @param message Custom message for the error
|
||||
* @return Exceptions
|
||||
*/
|
||||
inline exception make_exception(const std::string& context, const std::string& message = "")
|
||||
{
|
||||
return {0, context + ": " + message};
|
||||
}
|
||||
|
||||
inline exception
|
||||
make_exception(const std::string& context, unsigned int e, const std::string& message = "")
|
||||
{
|
||||
return {e, context + ": " + message};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a message of a file location
|
||||
*
|
||||
* @param file The filename
|
||||
* @param line The line number
|
||||
*
|
||||
* @return A string that represents the file location
|
||||
*/
|
||||
inline std::string make_source_context(const std::string& file, int line, const std::string& fname)
|
||||
{
|
||||
return file + ":" + std::to_string(line) + ": " + fname;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_MAKE_SOURCE_CTX() migraphx::make_source_context(__FILE__, __LINE__, __func__)
|
||||
|
||||
/**
|
||||
* @brief Throw an exception with context information
|
||||
*/
|
||||
#define MIGRAPHX_THROW(...) throw migraphx::make_exception(MIGRAPHX_MAKE_SOURCE_CTX(), __VA_ARGS__)
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
|
||||
|
||||
#include <migraphx/any_ptr.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct execution_environment
|
||||
{
|
||||
any_ptr queue = any_ptr{};
|
||||
bool async = false;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
|
||||
41
docker/rocm/migraphx/include/migraphx/fallthrough.hpp
Normal file
41
docker/rocm/migraphx/include/migraphx/fallthrough.hpp
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_FALLTHROUGH_HPP
|
||||
#define MIGRAPHX_GUARD_FALLTHROUGH_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifdef __clang__
|
||||
#define MIGRAPHX_FALLTHROUGH [[clang::fallthrough]]
|
||||
#else
|
||||
#define MIGRAPHX_FALLTHROUGH
|
||||
#endif
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
46
docker/rocm/migraphx/include/migraphx/file_buffer.hpp
Normal file
46
docker/rocm/migraphx/include/migraphx/file_buffer.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FILE_BUFFER_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FILE_BUFFER_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/filesystem.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_EXPORT std::vector<char>
|
||||
read_buffer(const fs::path& filename, size_t offset = 0, size_t nbytes = 0);
|
||||
MIGRAPHX_EXPORT std::string read_string(const fs::path& filename);
|
||||
|
||||
MIGRAPHX_EXPORT void write_string(const fs::path& filename, const std::string& buffer);
|
||||
MIGRAPHX_EXPORT void write_buffer(const fs::path& filename, const char* buffer, std::size_t size);
|
||||
MIGRAPHX_EXPORT void write_buffer(const fs::path& filename, const std::vector<char>& buffer);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
79
docker/rocm/migraphx/include/migraphx/filesystem.hpp
Normal file
79
docker/rocm/migraphx/include/migraphx/filesystem.hpp
Normal file
@ -0,0 +1,79 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FILESYSTEM_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FILESYSTEM_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
#if defined(CPPCHECK)
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 1
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
|
||||
#elif defined(_WIN32)
|
||||
#if _MSC_VER >= 1920
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 1
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
|
||||
#elif _MSC_VER >= 1900
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 0
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
|
||||
#else
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 0
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
|
||||
#endif
|
||||
#elif defined(__has_include)
|
||||
#if __has_include(<filesystem>) && __cplusplus >= 201703L
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 1
|
||||
#else
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 0
|
||||
#endif
|
||||
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
|
||||
#else
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
|
||||
#endif
|
||||
#else
|
||||
#define MIGRAPHX_HAS_FILESYSTEM 0
|
||||
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
|
||||
#endif
|
||||
|
||||
#if MIGRAPHX_HAS_FILESYSTEM
|
||||
#include <filesystem>
|
||||
#elif MIGRAPHX_HAS_FILESYSTEM_TS
|
||||
#include <experimental/filesystem>
|
||||
#else
|
||||
#error "No filesystem include available"
|
||||
#endif
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#if MIGRAPHX_HAS_FILESYSTEM
|
||||
namespace fs = ::std::filesystem;
|
||||
#elif MIGRAPHX_HAS_FILESYSTEM_TS
|
||||
namespace fs = ::std::experimental::filesystem;
|
||||
#endif
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
47
docker/rocm/migraphx/include/migraphx/fileutils.hpp
Normal file
47
docker/rocm/migraphx/include/migraphx/fileutils.hpp
Normal file
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_FILEUTILS_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_FILEUTILS_HPP
|
||||
|
||||
#include <migraphx/filesystem.hpp>
|
||||
#include <string_view>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_EXPORT fs::path make_executable_filename(std::string_view name);
|
||||
MIGRAPHX_EXPORT fs::path make_shared_object_filename(std::string_view name);
|
||||
MIGRAPHX_EXPORT fs::path make_object_file_filename(std::string_view name);
|
||||
MIGRAPHX_EXPORT fs::path make_static_library_filename(std::string_view name);
|
||||
MIGRAPHX_EXPORT fs::path append_extension(const fs::path& path, std::string_view ext);
|
||||
|
||||
inline std::string operator+(std::string l, const fs::path& r) { return std::move(l) + r.string(); }
|
||||
|
||||
inline std::string operator+(const fs::path& l, std::string r) { return l.string() + std::move(r); }
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHLIB_FILEUTILS_HPP
|
||||
452
docker/rocm/migraphx/include/migraphx/float8.hpp
Normal file
452
docker/rocm/migraphx/include/migraphx/float8.hpp
Normal file
@ -0,0 +1,452 @@
|
||||
/* ************************************************************************
|
||||
* Copyright (C) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
|
||||
* ies of the Software, and to permit persons to whom the Software is furnished
|
||||
* to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
|
||||
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
|
||||
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
* ************************************************************************ */
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
|
||||
|
||||
// We are clipping/saturation in down conversion by default. Unclipped version is not tested and
|
||||
// shouldn't be used without having enough tests.
|
||||
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <climits>
|
||||
#include <cstring>
|
||||
#include <iosfwd>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/float8_impl.hpp>
|
||||
#include <migraphx/generic_float.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
namespace fp8 {
|
||||
|
||||
enum class rounding_mode
|
||||
{
|
||||
standard, // standard rounding is doing RNE -- round to nearest even
|
||||
stochastic
|
||||
};
|
||||
|
||||
enum class f8_type
|
||||
{
|
||||
bf8 = 0, // s1e5m2
|
||||
fp8 = 1 // s1e4m3
|
||||
};
|
||||
|
||||
template <typename T, bool FNUZ = true>
|
||||
class numeric_limits;
|
||||
|
||||
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
|
||||
struct float8
|
||||
{
|
||||
uint8_t data = 0x00;
|
||||
// default constructor
|
||||
constexpr float8() = default;
|
||||
// default copy constructor
|
||||
constexpr float8(const float8& y) = default;
|
||||
struct from_bits_t
|
||||
{
|
||||
};
|
||||
static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
||||
|
||||
explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
|
||||
|
||||
explicit constexpr float8(
|
||||
float v,
|
||||
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
|
||||
uint32_t rng = 0)
|
||||
{
|
||||
if constexpr(T == migraphx::fp8::f8_type::fp8)
|
||||
{
|
||||
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
|
||||
data = migraphx::fp8::impl::
|
||||
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
|
||||
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
|
||||
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
|
||||
data = migraphx::fp8::impl::
|
||||
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
|
||||
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
|
||||
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
|
||||
data = migraphx::fp8::impl::
|
||||
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
|
||||
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
|
||||
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
|
||||
data = migraphx::fp8::impl::
|
||||
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
|
||||
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
|
||||
#endif // rocblas_F8_downcast_clipping}
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr operator float() const
|
||||
{
|
||||
if constexpr(T == migraphx::fp8::f8_type::fp8)
|
||||
{
|
||||
return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
|
||||
} // else
|
||||
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
|
||||
}
|
||||
|
||||
inline explicit constexpr operator bool() const { return not is_zero(); }
|
||||
|
||||
inline constexpr bool is_zero() const
|
||||
{
|
||||
if constexpr(FNUZ)
|
||||
{
|
||||
return data == 0x00;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (data == 0x00) or (data == 0x80);
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr bool is_nan() const
|
||||
{
|
||||
if constexpr(FNUZ)
|
||||
{
|
||||
return data == 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(T == migraphx::fp8::f8_type::bf8)
|
||||
{
|
||||
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
|
||||
(data == 0xFE) or (data == 0xFF);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (data == 0x7F) or (data == 0xFF);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr bool is_inf() const
|
||||
{
|
||||
if constexpr(FNUZ)
|
||||
{
|
||||
return data == 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(T == migraphx::fp8::f8_type::bf8)
|
||||
{
|
||||
return (data == 0x7C) or (data == 0xFC);
|
||||
}
|
||||
else
|
||||
{
|
||||
// no infinities in e4m3fn, represent them as NaNs
|
||||
return (data == 0x7F) or (data == 0xFF);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
|
||||
constexpr float8& operator unary_op(const float8& rhs) \
|
||||
{ \
|
||||
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
|
||||
*this = static_cast<float8>(tmp); \
|
||||
return *this; \
|
||||
} \
|
||||
constexpr float8& operator unary_op(const float& rhs) \
|
||||
{ \
|
||||
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
|
||||
*this = static_cast<float8>(tmp); \
|
||||
return *this; \
|
||||
}
|
||||
|
||||
MIGRAPHX_FP8_UNARY_OP(*=, *)
|
||||
MIGRAPHX_FP8_UNARY_OP(-=, -)
|
||||
MIGRAPHX_FP8_UNARY_OP(+=, +)
|
||||
MIGRAPHX_FP8_UNARY_OP(/=, /)
|
||||
|
||||
inline constexpr float8& operator=(const float8& rhs) = default;
|
||||
inline constexpr float8& operator=(float8&& rhs) noexcept = default;
|
||||
|
||||
inline constexpr float8& operator=(float rhs)
|
||||
{
|
||||
*this = static_cast<float8>(rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline constexpr bool operator==(const float8& rhs) const
|
||||
{
|
||||
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
|
||||
return false;
|
||||
else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
inline constexpr bool operator<(const float8& rhs) const
|
||||
{
|
||||
const auto we = static_cast<float>(*this);
|
||||
const auto them = static_cast<float>(rhs);
|
||||
return we < them;
|
||||
}
|
||||
|
||||
inline constexpr bool operator>(const float8& rhs) const
|
||||
{
|
||||
const auto we = static_cast<float>(*this);
|
||||
const auto them = static_cast<float>(rhs);
|
||||
return we > them;
|
||||
}
|
||||
};
|
||||
|
||||
// https://onnx.ai/onnx/technical/float8.html
|
||||
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
|
||||
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
|
||||
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
|
||||
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
|
||||
/*
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
|
||||
inline constexpr U operator binary_op(const T& lhs, const T& rhs) \
|
||||
{ \
|
||||
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
|
||||
}
|
||||
|
||||
// TODO: these should return floats for binary ops
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \
|
||||
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
|
||||
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
|
||||
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
|
||||
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
|
||||
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
|
||||
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
|
||||
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
|
||||
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
|
||||
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
|
||||
MIGRAPHX_FP8_BINARY_OP(!=, T, bool)
|
||||
|
||||
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2)
|
||||
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn)
|
||||
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz)
|
||||
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz)
|
||||
*/
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs)
|
||||
{
|
||||
return os << static_cast<float>(rhs);
|
||||
}
|
||||
|
||||
inline fp8e4m3fnuz fabs(fp8e4m3fnuz v)
|
||||
{
|
||||
v.data = v.data & 0x7F; // NOLINT
|
||||
return v;
|
||||
}
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs)
|
||||
{
|
||||
return os << static_cast<float>(rhs);
|
||||
}
|
||||
|
||||
inline fp8e4m3fn fabs(fp8e4m3fn v)
|
||||
{
|
||||
v.data = v.data & 0x7F; // NOLINT
|
||||
return v;
|
||||
}
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs)
|
||||
{
|
||||
return os << static_cast<float>(rhs);
|
||||
}
|
||||
|
||||
inline fp8e5m2fnuz fabs(fp8e5m2fnuz v)
|
||||
{
|
||||
v.data = v.data & 0x7F; // NOLINT
|
||||
return v;
|
||||
}
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs)
|
||||
{
|
||||
return os << static_cast<float>(rhs);
|
||||
}
|
||||
|
||||
inline fp8e5m2 fabs(fp8e5m2 v)
|
||||
{
|
||||
v.data = v.data & 0x7F; // NOLINT
|
||||
return v;
|
||||
}
|
||||
template <>
|
||||
class numeric_limits<fp8e4m3fnuz>
|
||||
{
|
||||
public:
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); }
|
||||
// NOLINTNEXTLINE
|
||||
static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
|
||||
|
||||
static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); }
|
||||
// this is min value that is not DeNorm. DeNorm min is 0x01
|
||||
static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); }
|
||||
|
||||
static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); }
|
||||
};
|
||||
|
||||
template <>
|
||||
class numeric_limits<fp8e4m3fn>
|
||||
{
|
||||
public:
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); }
|
||||
// NOLINTNEXTLINE
|
||||
static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); }
|
||||
|
||||
static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
|
||||
// this is min value that is not DeNorm. DeNorm min is 0x01
|
||||
static constexpr fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
|
||||
|
||||
static constexpr fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); }
|
||||
};
|
||||
|
||||
template <>
|
||||
class numeric_limits<fp8e5m2fnuz>
|
||||
{
|
||||
public:
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
|
||||
|
||||
static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT
|
||||
{
|
||||
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
|
||||
}
|
||||
|
||||
static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); }
|
||||
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
|
||||
// this distinction. For the floating points we would end up using lowest most of the times.
|
||||
static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); }
|
||||
|
||||
static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); }
|
||||
};
|
||||
|
||||
template <>
|
||||
class numeric_limits<fp8e5m2>
|
||||
{
|
||||
public:
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); }
|
||||
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
|
||||
static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT
|
||||
|
||||
static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
|
||||
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
|
||||
// this distinction. For the floating points we would end up using lowest most of the times.
|
||||
static constexpr fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
|
||||
|
||||
static constexpr fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
|
||||
// 7C and FC both are infinity
|
||||
static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
|
||||
};
|
||||
} // namespace fp8
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
// =================================================================================================
|
||||
// define numeric limits for the new data type
|
||||
// NOLINTBEGIN(cert-dcl58-cpp)
|
||||
namespace std {
|
||||
|
||||
template <migraphx::fp8::f8_type T, bool FNUZ>
|
||||
inline bool isfinite(migraphx::fp8::float8<T, FNUZ> x)
|
||||
{
|
||||
return not x.is_inf() and not x.is_nan();
|
||||
}
|
||||
|
||||
template <migraphx::fp8::f8_type T, bool FNUZ>
|
||||
inline bool isnan(migraphx::fp8::float8<T, FNUZ> x)
|
||||
{
|
||||
return x.is_nan();
|
||||
}
|
||||
|
||||
template <migraphx::fp8::f8_type T, bool FNUZ>
|
||||
class numeric_limits<migraphx::fp8::float8<T, FNUZ>>
|
||||
: public migraphx::fp8::numeric_limits<migraphx::fp8::float8<T, FNUZ>>
|
||||
{
|
||||
};
|
||||
template <migraphx::fp8::f8_type T, bool FNUZ, class U>
|
||||
struct common_type<migraphx::fp8::float8<T, FNUZ>, U> : std::common_type<float, U>
|
||||
{
|
||||
};
|
||||
template <migraphx::fp8::f8_type T, bool FNUZ, class U>
|
||||
struct common_type<U, migraphx::fp8::float8<T, FNUZ>> : std::common_type<U, float>
|
||||
{
|
||||
};
|
||||
template <migraphx::fp8::f8_type T, bool FNUZ>
|
||||
struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::fp8::float8<T, FNUZ>>
|
||||
{
|
||||
using type = migraphx::fp8::float8<T, FNUZ>;
|
||||
};
|
||||
|
||||
template <migraphx::fp8::f8_type T1, bool FNUZ1, migraphx::fp8::f8_type T2, bool FNUZ2>
|
||||
struct common_type<migraphx::fp8::float8<T1, FNUZ1>, migraphx::fp8::float8<T2, FNUZ2>>
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F, bool FNUZ>
|
||||
struct common_type<migraphx::generic_float<E, M, F>,
|
||||
migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, FNUZ>>
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F, bool FNUZ>
|
||||
struct common_type<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, FNUZ>,
|
||||
migraphx::generic_float<E, M, F>>
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
|
||||
struct common_type<migraphx::generic_float<E, M, F>, migraphx::fp8::float8<T, FNUZ>>
|
||||
: std::common_type<float, float>
|
||||
{
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
|
||||
struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::generic_float<E, M, F>>
|
||||
: std::common_type<float, float>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
// NOLINTEND(cert-dcl58-cpp)
|
||||
// =================================================================================================
|
||||
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
|
||||
328
docker/rocm/migraphx/include/migraphx/float8_impl.hpp
Normal file
328
docker/rocm/migraphx/include/migraphx/float8_impl.hpp
Normal file
@ -0,0 +1,328 @@
|
||||
/* ************************************************************************
|
||||
* Copyright (C) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
|
||||
* ies of the Software, and to permit persons to whom the Software is furnished
|
||||
* to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
|
||||
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
|
||||
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
* ************************************************************************ */
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/bit_cast.hpp>
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
namespace fp8 {
|
||||
namespace impl {
|
||||
|
||||
// NOLINTBEGIN
|
||||
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan, bool Clip>
|
||||
constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
|
||||
{
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
// half is not supported for now
|
||||
constexpr bool is_half = false;
|
||||
static_assert(Wm + We == 7, "Wm+We==7");
|
||||
static_assert(is_float or is_half, "Only float can be cast to f8");
|
||||
|
||||
const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
|
||||
|
||||
if constexpr(sizeof(T) == 4)
|
||||
x = migraphx::bit_cast<uint32_t>(f_x);
|
||||
else
|
||||
x = migraphx::bit_cast<uint16_t>(f_x);
|
||||
|
||||
uint32_t head = 0;
|
||||
uint32_t mantissa = 0;
|
||||
int exponent = 0;
|
||||
uint32_t bias = 0;
|
||||
uint32_t sign = 0;
|
||||
if constexpr(sizeof(T) == 4)
|
||||
{
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
}
|
||||
else
|
||||
{
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm);
|
||||
uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1));
|
||||
|
||||
// Calcualte maximum singed value FLT_MAX, FLT_MIN
|
||||
uint32_t signed_max = signed_all_ones;
|
||||
if(not NegativeZeroNan)
|
||||
signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs
|
||||
{
|
||||
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
|
||||
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
|
||||
return 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
// calculate most common NaN mantissa for FP8, which is all Ones in binary
|
||||
uint32_t nan_mantissa = 1;
|
||||
for(auto i = 1; i < Wm; ++i)
|
||||
{
|
||||
nan_mantissa |= (nan_mantissa << 1);
|
||||
}
|
||||
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
|
||||
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
|
||||
{
|
||||
// infinity
|
||||
if(mantissa == 0)
|
||||
{
|
||||
if(sign == 0)
|
||||
return (Wm == 2) ? 0x7B : 0x7E;
|
||||
else
|
||||
return (Wm == 2) ? 0xFB : 0xFE;
|
||||
}
|
||||
else // NaNs
|
||||
return signed_inf + nan_mantissa;
|
||||
}
|
||||
}
|
||||
// handle positive zero
|
||||
if(x == 0)
|
||||
return 0;
|
||||
// handle negative zero
|
||||
else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
|
||||
{
|
||||
return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero
|
||||
}
|
||||
|
||||
/* First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
|
||||
RNE, no need to add rng. Then probably need to check whether there is carry and adjust
|
||||
exponent and mantissa again*/
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
|
||||
const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
f8_exponent is the converted f8 exponent with bias encoding
|
||||
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
the difference needs to be adjusted and mantissa shifted*/
|
||||
int act_exponent = 0;
|
||||
int f8_exponent = 0;
|
||||
int exponent_diff = 0;
|
||||
|
||||
if(exponent == 0 and mantissa != 0)
|
||||
{ // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
|
||||
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal
|
||||
has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some
|
||||
numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is
|
||||
2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1
|
||||
are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = 1 - bias;
|
||||
exponent_diff = f8_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
}
|
||||
else
|
||||
{ // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= f8_denormal_act_exponent)
|
||||
{
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
|
||||
For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16
|
||||
actual exponent is -7, it is actually larger due to the implict 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
}
|
||||
else
|
||||
{ // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff =
|
||||
0; // exponent_diff=0 does not mean there is no difference for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1u << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
// need to know whether the number is right in the middle of two adjacent fp8 numbers. use max
|
||||
// value of 31 to avoid undefined behaviour
|
||||
bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) ==
|
||||
(1u << std::min(31u, mfmt - Wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
|
||||
shift right as shift right could rip off some residual part and make something not midpoint look
|
||||
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
|
||||
midpoint, but after shift right by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= std::min(31u, uint32_t(exponent_diff));
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
|
||||
f8_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1u << (mfmt - Wm)) - 1;
|
||||
bool odd =
|
||||
mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1
|
||||
/*
|
||||
This part is doing rounding by adding mantissa part that is going to get dropped.
|
||||
e.g. if the dropped part for less than 0.5 than it would round down.
|
||||
if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained
|
||||
mantissa.
|
||||
For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and
|
||||
`xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part.
|
||||
For the odd case :
|
||||
this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained
|
||||
part making it RNE.
|
||||
For the even case : this will add xy0:10000000 + 000:01111111 which would
|
||||
round down and keep number Even
|
||||
*/
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(f8_exponent == 0 and ((1 << mfmt) & mantissa))
|
||||
{
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
else if((1 << (mfmt + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - Wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
// for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans
|
||||
const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 1 : 2);
|
||||
if(f8_exponent > max_exp)
|
||||
{
|
||||
if(Clip)
|
||||
return signed_max;
|
||||
else
|
||||
{
|
||||
// https://onnx.ai/onnx/technical/float8.html#cast
|
||||
if(NegativeZeroNan)
|
||||
return 0x80;
|
||||
else
|
||||
return (Wm == 2) ? signed_inf : signed_all_ones;
|
||||
}
|
||||
}
|
||||
|
||||
if(f8_exponent == 0 and mantissa == 0)
|
||||
return NegativeZeroNan ? 0 : (sign << 7);
|
||||
mantissa &= (1 << Wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << Wm) | mantissa;
|
||||
}
|
||||
// NOLINTEND
|
||||
|
||||
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan>
|
||||
constexpr T cast_from_f8(uint8_t x)
|
||||
{
|
||||
// half is not supported for now
|
||||
constexpr bool is_half = false;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_float or is_half, "Only float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
// NOLINTNEXTLINE
|
||||
T f_inf, f_neg_inf, f_nan, f_neg0;
|
||||
|
||||
if constexpr(is_float)
|
||||
{
|
||||
const uint32_t if_inf = 0x7F800000;
|
||||
const uint32_t if_neg_inf = 0xFF800000;
|
||||
const uint32_t if_nan = 0x7F800001;
|
||||
const uint32_t if_neg0 = 0x80000000;
|
||||
f_inf = migraphx::bit_cast<float>(if_inf);
|
||||
f_neg_inf = migraphx::bit_cast<float>(if_neg_inf);
|
||||
f_nan = migraphx::bit_cast<float>(if_nan);
|
||||
f_neg0 = migraphx::bit_cast<float>(if_neg0);
|
||||
}
|
||||
|
||||
if(x == 0)
|
||||
return 0;
|
||||
|
||||
uint32_t sign = x >> 7; // NOLINT
|
||||
uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT
|
||||
int exponent = (x & 0x7F) >> Wm; // NOLINT
|
||||
if(NegativeZeroNan)
|
||||
{
|
||||
if(x == 0x80)
|
||||
return f_nan;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == 0x80)
|
||||
return f_neg0;
|
||||
if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT
|
||||
return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan;
|
||||
else if(Wm == 3 and (x == 0x7F or x == 0xFF))
|
||||
return f_nan;
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
|
||||
const int exp_low_cutoff =
|
||||
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
|
||||
|
||||
// subnormal input
|
||||
if(exponent == 0)
|
||||
{
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + __builtin_clz(mantissa) - (32 - Wm);
|
||||
mantissa <<= sh; // NOLINT
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << Wm) - 1); // NOLINT
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - Wm; // NOLINT
|
||||
|
||||
// subnormal output (occurs when T=half, We=5, negative_zero_nan=true)
|
||||
if(exponent <= 0)
|
||||
{
|
||||
mantissa |= 1 << wmo; // NOLINT
|
||||
mantissa >>= 1 - exponent; // NOLINT
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if(sizeof(T) == 2)
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT
|
||||
else
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT
|
||||
return migraphx::bit_cast<T>(retval);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
} // namespace fp8
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
|
||||
69
docker/rocm/migraphx/include/migraphx/float_equal.hpp
Normal file
69
docker/rocm/migraphx/include/migraphx/float_equal.hpp
Normal file
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include <migraphx/requires.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/type_traits.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class... Ts>
|
||||
using common_type = typename std::common_type<Ts...>::type;
|
||||
|
||||
struct float_equal_fn
|
||||
{
|
||||
template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
|
||||
static bool apply(T x, T y)
|
||||
{
|
||||
return std::isfinite(x) and std::isfinite(y) and
|
||||
std::nextafter(x, std::numeric_limits<T>::lowest()) <= y and
|
||||
std::nextafter(x, std::numeric_limits<T>::max()) >= y;
|
||||
}
|
||||
|
||||
template <class T, MIGRAPHX_REQUIRES(not is_floating_point<T>{})>
|
||||
static bool apply(T x, T y)
|
||||
{
|
||||
return x == y;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
bool operator()(T x, U y) const
|
||||
{
|
||||
return float_equal_fn::apply<common_type<T, U>>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr float_equal_fn float_equal{};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
49
docker/rocm/migraphx/include/migraphx/fp8_ocp_to_fnuz.hpp
Normal file
49
docker/rocm/migraphx/include/migraphx/fp8_ocp_to_fnuz.hpp
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* Convert fp8e4m3fn to fp8e4m3fnuz for hardware that only supports fp8e4m3fnuz data types
|
||||
* intrinsically. Conversion uses the same bit representation and adjusts scaling factors at the
|
||||
* dequantization. Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the
|
||||
* floating point representation. This pass should run before simplify_qdq so that the scales and
|
||||
* zero points calculated by simplify_qdq have the correct adjusted scaling factors
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz
|
||||
{
|
||||
std::string name() const { return "fp8_ocp_to_fnuz"; }
|
||||
void apply(module_pass_manager& mpm) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
40
docker/rocm/migraphx/include/migraphx/fp8_types.hpp
Normal file
40
docker/rocm/migraphx/include/migraphx/fp8_types.hpp
Normal file
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FP8_TYPES_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_FP8_TYPES_HPP
|
||||
#include <migraphx/shape.hpp>
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
struct fp8_types
|
||||
{
|
||||
const std::set<shape::type_t> types = {shape::fp8e4m3fnuz_type,
|
||||
shape::fp8e5m2fnuz_type,
|
||||
shape::fp8e4m3fn_type,
|
||||
shape::fp8e5m2_type};
|
||||
|
||||
std::set<shape::type_t> get() const { return types; }
|
||||
};
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_FP8_TYPES_HPP
|
||||
52
docker/rocm/migraphx/include/migraphx/fp_to_double.hpp
Normal file
52
docker/rocm/migraphx/include/migraphx/fp_to_double.hpp
Normal file
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FP_TO_DOUBLE_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FP_TO_DOUBLE_HPP
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* Convert floating point values to double precision.
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT fp_to_double
|
||||
{
|
||||
std::set<shape::type_t> convert_fp_types = {shape::type_t::half_type,
|
||||
shape::type_t::float_type};
|
||||
std::string name() const { return "fp_to_double"; }
|
||||
void apply(module_pass_manager& mpm) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
271
docker/rocm/migraphx/include/migraphx/functional.hpp
Normal file
271
docker/rocm/migraphx/include/migraphx/functional.hpp
Normal file
@ -0,0 +1,271 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_FUNCTIONAL_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_FUNCTIONAL_HPP
|
||||
|
||||
#include <utility>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
// Similiar to decltype(auto) except it will propagate any substitution failures
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_RETURNS(...) \
|
||||
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
|
||||
|
||||
// Lifts an expression into a function object so it can be passed to a higher-order function
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_LIFT(...) \
|
||||
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
|
||||
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <class... Ts>
|
||||
constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
auto tuple_size(const T&)
|
||||
{
|
||||
return typename std::tuple_size<T>::type{};
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class R, class F>
|
||||
struct fix_f
|
||||
{
|
||||
F f;
|
||||
|
||||
template <class... Ts>
|
||||
R operator()(Ts&&... xs) const
|
||||
{
|
||||
return f(*this, std::forward<Ts>(xs)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <std::size_t...>
|
||||
struct seq
|
||||
{
|
||||
using type = seq;
|
||||
};
|
||||
|
||||
template <class, class>
|
||||
struct merge_seq;
|
||||
|
||||
template <std::size_t... Xs, std::size_t... Ys>
|
||||
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t N>
|
||||
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct gens<0> : seq<>
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct gens<1> : seq<0>
|
||||
{
|
||||
};
|
||||
|
||||
template <class F, std::size_t... Ns>
|
||||
constexpr void repeat_c_impl(F f, seq<Ns...>)
|
||||
{
|
||||
swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
|
||||
}
|
||||
|
||||
template <class F, std::size_t... Ns>
|
||||
constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
|
||||
{
|
||||
return f(std::integral_constant<std::size_t, Ns>{}...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <std::size_t N, class F>
|
||||
constexpr void repeat_c(F f)
|
||||
{
|
||||
detail::repeat_c_impl(f, detail::gens<N>{});
|
||||
}
|
||||
|
||||
template <std::size_t N, class F>
|
||||
constexpr auto sequence_c(F&& f)
|
||||
{
|
||||
return detail::sequence_c_impl(f, detail::gens<N>{});
|
||||
}
|
||||
|
||||
template <class IntegerConstant, class F>
|
||||
constexpr auto sequence(IntegerConstant ic, F&& f)
|
||||
{
|
||||
return sequence_c<ic>(f);
|
||||
}
|
||||
|
||||
template <class F, class... Ts>
|
||||
constexpr void each_args(F f, Ts&&... xs)
|
||||
{
|
||||
swallow{(f(std::forward<Ts>(xs)), 0)...};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
constexpr void each_args(F)
|
||||
{
|
||||
}
|
||||
|
||||
template <class F, class T>
|
||||
auto unpack(F f, T&& x)
|
||||
{
|
||||
return sequence(tuple_size(x), [&](auto... is) { f(std::get<is>(static_cast<T&&>(x))...); });
|
||||
}
|
||||
|
||||
/// Implements a fix-point combinator
|
||||
template <class R, class F>
|
||||
detail::fix_f<R, F> fix(F f)
|
||||
{
|
||||
return {f};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
auto fix(F f)
|
||||
{
|
||||
return fix<void>(f);
|
||||
}
|
||||
|
||||
template <class F, class T>
|
||||
auto fold_impl(F&&, T&& x)
|
||||
{
|
||||
return std::forward<T>(x);
|
||||
}
|
||||
|
||||
template <class F, class T, class U, class... Ts>
|
||||
auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
|
||||
{
|
||||
return fold_impl(f, f(std::forward<T>(x), std::forward<U>(y)), std::forward<Ts>(xs)...);
|
||||
}
|
||||
|
||||
template <class F>
|
||||
auto fold(F f)
|
||||
{
|
||||
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
auto pack(Ts... xs)
|
||||
{
|
||||
return [=](auto f) { return f(xs...); };
|
||||
}
|
||||
|
||||
inline auto pack_join() { return pack(); }
|
||||
|
||||
template <class... Ps>
|
||||
auto pack_join(Ps... ps)
|
||||
{
|
||||
return fold([](auto p1, auto p2) {
|
||||
return p1([=](auto... xs) { return p2([=](auto... ys) { return pack(xs..., ys...); }); });
|
||||
})(ps...);
|
||||
}
|
||||
|
||||
template <class F, class Proj>
|
||||
auto by(F f, Proj proj)
|
||||
{
|
||||
return [=](auto&&... xs) { return f(proj(std::forward<decltype(xs)>(xs))...); };
|
||||
}
|
||||
|
||||
template <class T>
|
||||
auto index_of(T& x)
|
||||
{
|
||||
return [&](auto&& y) { return x[y]; };
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
decltype(auto) front_args(T&& x, Ts&&...)
|
||||
{
|
||||
return static_cast<T&&>(x);
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
decltype(auto) back_args(Ts&&... xs)
|
||||
{
|
||||
return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...));
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
auto pop_front_args(T&&, Ts&&... xs)
|
||||
{
|
||||
return [&](auto f) { f(static_cast<Ts&&>(xs)...); };
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
auto pop_back_args(Ts&&... xs)
|
||||
{
|
||||
return [&](auto f) {
|
||||
using tuple_type = std::tuple<Ts&&...>;
|
||||
auto t = tuple_type(static_cast<Ts&&>(xs)...);
|
||||
return sequence_c<sizeof...(Ts) - 1>(
|
||||
[&](auto... is) { return f(std::get<is>(static_cast<tuple_type&&>(t))...); });
|
||||
};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct always_f
|
||||
{
|
||||
T x;
|
||||
template <class... Ts>
|
||||
constexpr T operator()(Ts&&...) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
auto always(T x)
|
||||
{
|
||||
return always_f<T>{x};
|
||||
}
|
||||
|
||||
struct id
|
||||
{
|
||||
template <class T>
|
||||
constexpr T operator()(T&& x) const
|
||||
{
|
||||
return static_cast<T&&>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <class... Ts>
|
||||
void nop(Ts&&...)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
43
docker/rocm/migraphx/include/migraphx/fuse_concat.hpp
Normal file
43
docker/rocm/migraphx/include/migraphx/fuse_concat.hpp
Normal file
@ -0,0 +1,43 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module_pass_manager;
|
||||
|
||||
struct MIGRAPHX_EXPORT fuse_concat
|
||||
{
|
||||
std::string name() const { return "fuse_concat"; }
|
||||
void apply(module_pass_manager& mpm) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
|
||||
46
docker/rocm/migraphx/include/migraphx/fuse_pointwise.hpp
Normal file
46
docker/rocm/migraphx/include/migraphx/fuse_pointwise.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module_pass_manager;
|
||||
|
||||
struct MIGRAPHX_EXPORT fuse_pointwise
|
||||
{
|
||||
std::string name() const { return "fuse_pointwise"; }
|
||||
void apply(module_pass_manager& mpm) const;
|
||||
|
||||
bool enable_rewrite_reshapes = true;
|
||||
bool enable_rewrite_broadcasts = false;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
|
||||
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_REDUCE_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_REDUCE_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module_pass_manager;
|
||||
|
||||
struct MIGRAPHX_EXPORT fuse_pointwise_reduce
|
||||
{
|
||||
std::size_t split_size = 32768;
|
||||
std::string name() const { return "fuse_pointwise_reduce"; }
|
||||
void apply(module_pass_manager& mpm) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_REDUCE_HPP
|
||||
45
docker/rocm/migraphx/include/migraphx/fuse_reduce.hpp
Normal file
45
docker/rocm/migraphx/include/migraphx/fuse_reduce.hpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module_pass_manager;
|
||||
|
||||
struct MIGRAPHX_EXPORT fuse_reduce
|
||||
{
|
||||
std::string name() const { return "fuse_reduce"; }
|
||||
void apply(module_pass_manager& mpm) const;
|
||||
|
||||
bool enable_rewrite_reshapes = true;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
|
||||
65
docker/rocm/migraphx/include/migraphx/gemm.hpp
Normal file
65
docker/rocm/migraphx/include/migraphx/gemm.hpp
Normal file
@ -0,0 +1,65 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/dfor.hpp>
|
||||
#include <migraphx/par_for.hpp>
|
||||
#include <migraphx/tensor_view.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class T, class U, class F>
|
||||
void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
|
||||
{
|
||||
std::size_t n_dims = cmat.get_shape().lens().size();
|
||||
std::size_t dim_0 = n_dims - 2;
|
||||
std::size_t dim_1 = n_dims - 1;
|
||||
auto k = amat.get_shape().lens()[dim_1];
|
||||
|
||||
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
|
||||
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
|
||||
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
|
||||
auto cs = cmat.get_shape();
|
||||
|
||||
par_for(cs.elements(), [&](auto i) {
|
||||
auto c_idx = cs.multi(i);
|
||||
auto a_idx = c_idx;
|
||||
auto b_idx = c_idx;
|
||||
double s = 0.0;
|
||||
dfor(k)([&](auto kk) {
|
||||
a_idx[dim_1] = b_idx[dim_0] = kk;
|
||||
s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
|
||||
static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
|
||||
});
|
||||
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
154
docker/rocm/migraphx/include/migraphx/generate.hpp
Normal file
154
docker/rocm/migraphx/include/migraphx/generate.hpp
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
|
||||
|
||||
#include <migraphx/argument.hpp>
|
||||
#include <migraphx/literal.hpp>
|
||||
#include <migraphx/type_traits.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <random>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
enum class random_mode
|
||||
{
|
||||
legacy,
|
||||
random
|
||||
};
|
||||
|
||||
template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
|
||||
constexpr T normalize(unsigned long z, random_mode m)
|
||||
{
|
||||
auto max = (m == random_mode::legacy) ? 32 : 1ULL << (sizeof(T) * 8 - 1);
|
||||
const double range = max / 2.0;
|
||||
double result = -1.0 + (z % max) / range;
|
||||
// Expected output: between -1.0 and 1.0
|
||||
return T(result);
|
||||
}
|
||||
|
||||
template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
|
||||
constexpr T normalize(unsigned long z, random_mode m)
|
||||
{
|
||||
const long long max =
|
||||
(m == random_mode::legacy) ? 1ULL << (sizeof(T) * 5) : 1ULL << (sizeof(T) * 6 - 1);
|
||||
const auto half_max = max / 2;
|
||||
auto result = half_max - (z % max);
|
||||
// Expected output: between -half_max and half_max
|
||||
return T(result);
|
||||
}
|
||||
|
||||
template <class T,
|
||||
MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{} and
|
||||
not std::is_same<T, bool>{})>
|
||||
constexpr T normalize(unsigned long z, random_mode m)
|
||||
{
|
||||
const auto max =
|
||||
(m == random_mode::legacy) ? 1ULL << (sizeof(T) * 5) : 1ULL << (sizeof(T) * 8 - 1);
|
||||
// Expected output: between 0 and max - 1
|
||||
return z % max;
|
||||
}
|
||||
|
||||
template <class T, MIGRAPHX_REQUIRES(std::is_same<T, bool>{})>
|
||||
constexpr bool normalize(unsigned long z, random_mode)
|
||||
{
|
||||
// Expected output: 0 or 1b
|
||||
return static_cast<bool>(z % 2);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct xorshf96_generator
|
||||
{
|
||||
unsigned long x = 123456789;
|
||||
unsigned long y = 362436069;
|
||||
unsigned long z;
|
||||
random_mode mode;
|
||||
|
||||
xorshf96_generator(unsigned long seed, random_mode m) : z(521288629ULL ^ seed), mode(m) {}
|
||||
|
||||
constexpr T operator()() noexcept
|
||||
{
|
||||
x ^= x << 16U;
|
||||
x ^= x >> 5U;
|
||||
x ^= x << 1U;
|
||||
|
||||
unsigned long t = x;
|
||||
x = y;
|
||||
y = z;
|
||||
z = t ^ x ^ y;
|
||||
|
||||
return normalize<T>(z, mode);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct xorshift_generator
|
||||
{
|
||||
unsigned long x;
|
||||
random_mode mode;
|
||||
|
||||
xorshift_generator(unsigned long seed, random_mode m) : x(521288629ULL ^ seed), mode(m) {}
|
||||
|
||||
constexpr T operator()() noexcept
|
||||
{
|
||||
x ^= x >> 12U;
|
||||
x ^= x << 25U;
|
||||
x ^= x >> 27U;
|
||||
return normalize<T>(x * 0x2545F4914F6CDD1D, mode);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
auto generate_tensor_data(const migraphx::shape& s,
|
||||
unsigned long seed,
|
||||
random_mode m = random_mode::legacy)
|
||||
{
|
||||
auto result = make_shared_array<T>(s.element_space());
|
||||
std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator<T>{seed, m});
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
auto fill_tensor_data(const migraphx::shape& s, double value = 0)
|
||||
{
|
||||
auto result = make_shared_array<T>(s.element_space());
|
||||
std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
|
||||
return result;
|
||||
}
|
||||
|
||||
MIGRAPHX_EXPORT argument fill_argument(shape s, double value = 0);
|
||||
|
||||
MIGRAPHX_EXPORT argument generate_argument(shape s,
|
||||
unsigned long seed = 0,
|
||||
random_mode m = random_mode::legacy);
|
||||
|
||||
MIGRAPHX_EXPORT literal generate_literal(shape s, unsigned long seed = 0);
|
||||
|
||||
MIGRAPHX_EXPORT literal abs(literal l);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
476
docker/rocm/migraphx/include/migraphx/generic_float.hpp
Normal file
476
docker/rocm/migraphx/include/migraphx/generic_float.hpp
Normal file
@ -0,0 +1,476 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/bit_cast.hpp>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <iostream>
|
||||
#include <tuple>
|
||||
#include <cstdint>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <unsigned int N>
|
||||
constexpr unsigned int all_ones() noexcept
|
||||
{
|
||||
return (1u << N) - 1u;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr int countl_zero(T value)
|
||||
{
|
||||
unsigned int r = 0;
|
||||
for(; value != 0u; value >>= 1u)
|
||||
r++;
|
||||
return 8 * sizeof(value) - r;
|
||||
}
|
||||
|
||||
constexpr std::size_t bit_ceil(std::size_t v)
|
||||
{
|
||||
if(v <= 1)
|
||||
return 1;
|
||||
v--;
|
||||
v |= v >> 1u;
|
||||
v |= v >> 2u;
|
||||
v |= v >> 4u;
|
||||
v |= v >> 8u;
|
||||
v |= v >> 16u;
|
||||
v |= v >> 32u;
|
||||
return v + 1;
|
||||
}
|
||||
|
||||
constexpr std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
|
||||
{
|
||||
return (x + y - std::size_t{1}) / y;
|
||||
}
|
||||
|
||||
template <unsigned int Bytes>
|
||||
struct unsigned_type
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unsigned_type<1>
|
||||
{
|
||||
using type = std::uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unsigned_type<2>
|
||||
{
|
||||
using type = std::uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unsigned_type<4>
|
||||
{
|
||||
using type = std::uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unsigned_type<8>
|
||||
{
|
||||
using type = std::uint64_t;
|
||||
};
|
||||
|
||||
struct float32_parts
|
||||
{
|
||||
unsigned int mantissa : 23;
|
||||
unsigned int exponent : 8;
|
||||
unsigned int sign : 1;
|
||||
|
||||
static constexpr unsigned int exponent_width() { return 8; }
|
||||
|
||||
static constexpr unsigned int mantissa_width() { return 23; }
|
||||
|
||||
static constexpr unsigned int max_exponent() { return all_ones<8>(); }
|
||||
|
||||
static constexpr int exponent_bias() { return all_ones<7>(); }
|
||||
|
||||
constexpr float to_float() const noexcept { return migraphx::bit_cast<float>(*this); }
|
||||
};
|
||||
|
||||
constexpr float32_parts get_parts(float f) { return migraphx::bit_cast<float32_parts>(f); }
|
||||
|
||||
template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
|
||||
struct __attribute__((packed, may_alias)) generic_float
|
||||
{
|
||||
using type = typename unsigned_type<bit_ceil(
|
||||
integer_divide_ceil(MantissaSize + ExponentSize + 1, 8))>::type;
|
||||
|
||||
type mantissa : MantissaSize;
|
||||
type exponent : ExponentSize;
|
||||
type sign : 1;
|
||||
|
||||
static constexpr int exponent_bias() { return all_ones<ExponentSize - 1>(); }
|
||||
|
||||
explicit constexpr generic_float(float f = 0.0) noexcept { from_float(get_parts(f)); }
|
||||
|
||||
constexpr generic_float& operator=(float f) noexcept
|
||||
{
|
||||
from_float(get_parts(f));
|
||||
return *this;
|
||||
}
|
||||
|
||||
constexpr generic_float operator-() const noexcept
|
||||
{
|
||||
generic_float result = *this;
|
||||
result.sign = not this->sign;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr generic_float operator+() const noexcept { return *this; }
|
||||
|
||||
constexpr float to_float() const noexcept
|
||||
{
|
||||
float32_parts f{};
|
||||
f.sign = sign;
|
||||
|
||||
if(exponent == 0 and ExponentSize != float32_parts::exponent_width()) // subnormal fps
|
||||
{
|
||||
|
||||
if(mantissa == 0)
|
||||
{
|
||||
f.exponent = 0;
|
||||
f.mantissa = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
type shift = 0;
|
||||
f.mantissa = mantissa;
|
||||
|
||||
if(MantissaSize < float32_parts::mantissa_width())
|
||||
{
|
||||
shift = MantissaSize - ((sizeof(type) * 8) - countl_zero(mantissa));
|
||||
f.mantissa <<= (shift + 1u);
|
||||
}
|
||||
|
||||
f.exponent = float32_parts::exponent_bias() - exponent_bias() - shift;
|
||||
f.mantissa = f.mantissa << (float32_parts::mantissa_width() - MantissaSize);
|
||||
}
|
||||
}
|
||||
else if(exponent == all_ones<ExponentSize>())
|
||||
{
|
||||
f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
|
||||
f.exponent = float32_parts::max_exponent();
|
||||
}
|
||||
else
|
||||
{
|
||||
f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
|
||||
constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
|
||||
f.exponent = int(exponent) + diff;
|
||||
}
|
||||
|
||||
return f.to_float();
|
||||
}
|
||||
|
||||
constexpr void from_float(float32_parts f) noexcept
|
||||
{
|
||||
sign = f.sign;
|
||||
|
||||
if(f.exponent == 0)
|
||||
{
|
||||
exponent = 0;
|
||||
mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);
|
||||
}
|
||||
else if(f.exponent == float32_parts::max_exponent())
|
||||
{
|
||||
exponent = all_ones<ExponentSize>();
|
||||
mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
|
||||
auto e = int(f.exponent) - diff;
|
||||
|
||||
if(e >= static_cast<int>(all_ones<ExponentSize>()))
|
||||
{
|
||||
exponent = all_ones<ExponentSize>();
|
||||
mantissa = 0;
|
||||
}
|
||||
else if(e < 1)
|
||||
{
|
||||
exponent = 0;
|
||||
|
||||
auto shift = diff - int(f.exponent);
|
||||
auto shift_amount = shift + (float32_parts::mantissa_width() - MantissaSize) + 1;
|
||||
|
||||
if(shift_amount < (sizeof(unsigned int) * 8))
|
||||
{
|
||||
mantissa = (f.mantissa | (1u << float32_parts::mantissa_width())) >>
|
||||
(shift + (float32_parts::mantissa_width() - MantissaSize) + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
mantissa = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
exponent = int(f.exponent) - diff;
|
||||
mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);
|
||||
}
|
||||
}
|
||||
|
||||
exponent = std::min<type>(exponent, all_ones<ExponentSize>());
|
||||
}
|
||||
|
||||
constexpr bool is_normal() const noexcept
|
||||
{
|
||||
return exponent != all_ones<ExponentSize>() and exponent != 0;
|
||||
}
|
||||
|
||||
constexpr bool is_inf() const noexcept
|
||||
{
|
||||
return exponent == all_ones<ExponentSize>() and mantissa == 0;
|
||||
}
|
||||
|
||||
constexpr bool is_nan() const noexcept
|
||||
{
|
||||
return exponent == all_ones<ExponentSize>() and mantissa != 0;
|
||||
}
|
||||
|
||||
constexpr bool is_finite() const noexcept { return exponent != all_ones<ExponentSize>(); }
|
||||
|
||||
constexpr operator float() const noexcept { return this->to_float(); }
|
||||
|
||||
static constexpr generic_float infinity()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = all_ones<ExponentSize>();
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float snan()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = all_ones<ExponentSize>();
|
||||
x.mantissa = 1u << (MantissaSize - 2u);
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float qnan()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = all_ones<ExponentSize>();
|
||||
x.mantissa = 1u << (MantissaSize - 1u);
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float min()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = 1;
|
||||
x.mantissa = 0;
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float denorm_min()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = 0;
|
||||
x.mantissa = 1;
|
||||
x.sign = 0;
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float lowest()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = all_ones<ExponentSize>() - 1;
|
||||
x.mantissa = all_ones<MantissaSize>();
|
||||
x.sign = 1;
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float max()
|
||||
{
|
||||
generic_float x{};
|
||||
x.exponent = all_ones<ExponentSize>() - 1;
|
||||
x.mantissa = all_ones<MantissaSize>();
|
||||
x.sign = 0;
|
||||
return x;
|
||||
}
|
||||
|
||||
static constexpr generic_float epsilon()
|
||||
{
|
||||
generic_float x{1.0};
|
||||
x.mantissa++;
|
||||
return generic_float{x.to_float() - 1.0f};
|
||||
}
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
|
||||
constexpr generic_float& operator op(const generic_float & rhs) \
|
||||
{ \
|
||||
float self = *this; \
|
||||
float frhs = rhs; \
|
||||
self op frhs; \
|
||||
*this = generic_float(self); \
|
||||
return *this; \
|
||||
}
|
||||
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
|
||||
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
|
||||
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
|
||||
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
|
||||
friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \
|
||||
{ \
|
||||
return generic_float(float(x) op float(y)); \
|
||||
}
|
||||
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
|
||||
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
|
||||
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
|
||||
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \
|
||||
friend constexpr bool operator op(const generic_float& x, const generic_float& y) \
|
||||
{ \
|
||||
return float(x) op float(y); \
|
||||
}
|
||||
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<)
|
||||
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=)
|
||||
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>)
|
||||
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=)
|
||||
|
||||
friend constexpr bool operator==(const generic_float& x, const generic_float& y)
|
||||
{
|
||||
if(not x.is_finite() or not y.is_finite())
|
||||
return false;
|
||||
|
||||
if((x.mantissa == 0 and x.exponent == 0) and (y.mantissa == 0 and y.exponent == 0))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return std::tie(x.mantissa, x.exponent, x.sign) == std::tie(y.mantissa, y.exponent, y.sign);
|
||||
}
|
||||
|
||||
friend constexpr bool operator!=(const generic_float& x, const generic_float& y)
|
||||
{
|
||||
return not(x == y);
|
||||
}
|
||||
|
||||
constexpr generic_float& operator++() noexcept
|
||||
{
|
||||
*this += generic_float(1.0f);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const generic_float operator++(int) noexcept // NOLINT(readability-const-return-type)
|
||||
{
|
||||
generic_float temp = *this;
|
||||
*this += generic_float(1.0f);
|
||||
return temp;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
// NOLINTBEGIN(cert-dcl58-cpp)
|
||||
namespace std {
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F>
|
||||
class numeric_limits<migraphx::generic_float<E, M, F>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr migraphx::generic_float<E, M, F> epsilon()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::epsilon();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> quiet_NaN()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::qnan();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> signaling_NaN()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::snan();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> max()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::max();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> min()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::min();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> lowest()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::lowest();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> infinity()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::infinity();
|
||||
}
|
||||
|
||||
static constexpr migraphx::generic_float<E, M, F> denorm_min()
|
||||
{
|
||||
return migraphx::generic_float<E, M, F>::denorm_min();
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F, class T>
|
||||
struct common_type<migraphx::generic_float<E, M, F>, T> : std::common_type<float, T>
|
||||
{
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F, class T>
|
||||
struct common_type<T, migraphx::generic_float<E, M, F>> : std::common_type<float, T>
|
||||
{
|
||||
};
|
||||
|
||||
template <unsigned int E, unsigned int M, unsigned int F>
|
||||
struct common_type<migraphx::generic_float<E, M, F>, migraphx::generic_float<E, M, F>>
|
||||
{
|
||||
using type = migraphx::generic_float<E, M, F>;
|
||||
};
|
||||
|
||||
template <unsigned int E1,
|
||||
unsigned int M1,
|
||||
unsigned int F1,
|
||||
unsigned int E2,
|
||||
unsigned int M2,
|
||||
unsigned int F2>
|
||||
struct common_type<migraphx::generic_float<E1, M1, F1>, migraphx::generic_float<E2, M2, F2>>
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
// NOLINTEND(cert-dcl58-cpp)
|
||||
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP
|
||||
51
docker/rocm/migraphx/include/migraphx/half.hpp
Normal file
51
docker/rocm/migraphx/include/migraphx/half.hpp
Normal file
@ -0,0 +1,51 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_HALF_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/float8.hpp>
|
||||
#include <migraphx/generic_float.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
using half = migraphx::generic_float<10, 5>;
|
||||
|
||||
namespace detail {
|
||||
template <class T>
|
||||
struct deduce
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <class T>
|
||||
using deduce = typename detail::deduce<T>::type;
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
47
docker/rocm/migraphx/include/migraphx/hash.hpp
Normal file
47
docker/rocm/migraphx/include/migraphx/hash.hpp
Normal file
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <functional>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class T>
|
||||
std::size_t hash_value(const T& v)
|
||||
{
|
||||
return std::hash<T>{}(v);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void hash_combine(std::size_t& seed, const T& v)
|
||||
{
|
||||
seed ^= hash_value(v) + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
|
||||
45
docker/rocm/migraphx/include/migraphx/inline_module.hpp
Normal file
45
docker/rocm/migraphx/include/migraphx/inline_module.hpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
|
||||
|
||||
#include <string>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
struct MIGRAPHX_EXPORT inline_module
|
||||
{
|
||||
std::string name() const { return "inline_module"; }
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
53
docker/rocm/migraphx/include/migraphx/insert_pad.hpp
Normal file
53
docker/rocm/migraphx/include/migraphx/insert_pad.hpp
Normal file
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct module;
|
||||
|
||||
/**
|
||||
* insert pads if attribute of padding is asymmetrical
|
||||
*/
|
||||
struct MIGRAPHX_EXPORT insert_pad
|
||||
{
|
||||
std::unordered_set<std::string> ops = {"convolution", "pooling", "im2col"};
|
||||
std::string name() const { return "insert_pad"; }
|
||||
|
||||
void apply(module& m) const;
|
||||
};
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
189
docker/rocm/migraphx/include/migraphx/instruction.hpp
Normal file
189
docker/rocm/migraphx/include/migraphx/instruction.hpp
Normal file
@ -0,0 +1,189 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
|
||||
#define MIGRAPHX_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
|
||||
|
||||
#include <migraphx/literal.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/module_ref.hpp>
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/erase.hpp>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_EXPORT shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
|
||||
MIGRAPHX_EXPORT shape compute_shape(const operation& op,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const std::vector<module_ref>& mods);
|
||||
MIGRAPHX_EXPORT std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
|
||||
MIGRAPHX_EXPORT std::vector<shape> try_compute_shape(const operation& op,
|
||||
const std::vector<shape>& inputs);
|
||||
|
||||
MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end);
|
||||
|
||||
struct MIGRAPHX_EXPORT instruction
|
||||
{
|
||||
instruction() {}
|
||||
|
||||
instruction(operation o, shape r, std::vector<instruction_ref> args);
|
||||
|
||||
instruction(operation o,
|
||||
shape r,
|
||||
std::vector<instruction_ref> args,
|
||||
std::vector<module_ref> modules);
|
||||
|
||||
instruction(literal l);
|
||||
|
||||
void replace(operation o);
|
||||
|
||||
void recompute_shape();
|
||||
|
||||
void clear_arguments();
|
||||
|
||||
MIGRAPHX_EXPORT friend bool operator==(const instruction& i, instruction_ref ref);
|
||||
|
||||
bool valid(instruction_ref start, bool check_order = false) const;
|
||||
|
||||
bool valid() const;
|
||||
|
||||
shape get_shape() const;
|
||||
const literal& get_literal() const;
|
||||
|
||||
const operation& get_operator() const;
|
||||
|
||||
std::string name() const;
|
||||
|
||||
const std::vector<instruction_ref>& inputs() const;
|
||||
|
||||
const std::vector<module_ref>& module_inputs() const;
|
||||
|
||||
/// Where this instruction is used as an input to another instruction
|
||||
const std::vector<instruction_ref>& outputs() const;
|
||||
|
||||
MIGRAPHX_EXPORT friend bool operator==(const instruction& x, const instruction& y);
|
||||
|
||||
MIGRAPHX_EXPORT friend bool operator!=(const instruction& x, const instruction& y);
|
||||
|
||||
MIGRAPHX_EXPORT friend bool operator==(instruction_ref ref, const instruction& i);
|
||||
|
||||
MIGRAPHX_EXPORT friend bool operator!=(const instruction& i, instruction_ref ref);
|
||||
|
||||
MIGRAPHX_EXPORT friend bool operator!=(instruction_ref ref, const instruction& i);
|
||||
|
||||
void add_output(instruction_ref ins);
|
||||
|
||||
template <class T>
|
||||
void remove_output(const T& ins)
|
||||
{
|
||||
migraphx::erase(output, ins);
|
||||
}
|
||||
|
||||
static void replace_refs(instruction_ref ins,
|
||||
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
|
||||
const std::unordered_map<module_ref, module_ref>& map_mods);
|
||||
|
||||
static void backreference(instruction_ref ref);
|
||||
|
||||
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
|
||||
|
||||
static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod);
|
||||
|
||||
static void
|
||||
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
|
||||
|
||||
static void replace(instruction_ref ins,
|
||||
operation o,
|
||||
const shape& r,
|
||||
std::vector<instruction_ref> args,
|
||||
std::vector<module_ref> module_args);
|
||||
|
||||
bool can_eval() const;
|
||||
|
||||
bool is_undefined() const;
|
||||
|
||||
argument eval(bool check_eval = true) const;
|
||||
|
||||
void finalize(context& ctx);
|
||||
|
||||
static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);
|
||||
|
||||
void set_normalized(bool value = true);
|
||||
bool is_normalized() const;
|
||||
|
||||
bool need_normalization() const;
|
||||
|
||||
operation normalized_operator() const;
|
||||
|
||||
std::size_t get_target_id() const;
|
||||
|
||||
void set_target_id(std::size_t tid);
|
||||
|
||||
void debug_print() const;
|
||||
|
||||
static void print(std::ostream& os,
|
||||
instruction_ref ins,
|
||||
const std::unordered_map<instruction_ref, std::string>& names);
|
||||
|
||||
private:
|
||||
// internal
|
||||
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
|
||||
|
||||
// internal
|
||||
void replace(operation o,
|
||||
const shape& r,
|
||||
std::vector<instruction_ref> args,
|
||||
std::vector<module_ref> mdl_args);
|
||||
|
||||
// internal
|
||||
void replace(std::vector<instruction_ref> args);
|
||||
|
||||
// internal
|
||||
void replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args);
|
||||
|
||||
// internal
|
||||
void replace_argument(instruction_ref old, instruction_ref new_ins);
|
||||
|
||||
// internal
|
||||
void replace_mod_argument(module_ref old, module_ref new_mod);
|
||||
|
||||
void replace(const shape& r);
|
||||
|
||||
operation op;
|
||||
shape result{};
|
||||
std::vector<instruction_ref> output;
|
||||
std::vector<instruction_ref> arguments;
|
||||
std::vector<module_ref> module_args;
|
||||
literal lit;
|
||||
bool normalized = false;
|
||||
std::size_t target_id = 0;
|
||||
};
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
102
docker/rocm/migraphx/include/migraphx/instruction_ref.hpp
Normal file
102
docker/rocm/migraphx/include/migraphx/instruction_ref.hpp
Normal file
@ -0,0 +1,102 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_INSTRUCTION_REF_HPP
|
||||
#define MIGRAPHX_GUARD_INSTRUCTION_REF_HPP
|
||||
|
||||
#include <list>
|
||||
#include <functional>
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/requires.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct instruction;
|
||||
#if defined(_WIN32) && !defined(NDEBUG) && !defined(CPPCHECK)
|
||||
struct instruction_ref : std::list<instruction>::iterator
|
||||
{
|
||||
using instruction_iter = std::list<instruction>::iterator;
|
||||
using instruction_const_iter = std::list<instruction>::const_iterator;
|
||||
|
||||
instruction_ref() = default;
|
||||
instruction_ref(const instruction_iter& other) : instruction_iter(other) {}
|
||||
|
||||
template <class T,
|
||||
class U,
|
||||
MIGRAPHX_REQUIRES(std::is_same<T, instruction_ref>{} or
|
||||
std::is_same<U, instruction_ref>{})>
|
||||
friend bool operator==(const T& x, const U& y)
|
||||
{
|
||||
return x._Unwrapped()._Ptr == y._Unwrapped()._Ptr;
|
||||
}
|
||||
|
||||
template <class T,
|
||||
class U,
|
||||
MIGRAPHX_REQUIRES(std::is_same<T, instruction_ref>{} or
|
||||
std::is_same<U, instruction_ref>{})>
|
||||
friend bool operator!=(const T& x, const U& y)
|
||||
{
|
||||
return not(x == y);
|
||||
}
|
||||
};
|
||||
#else
|
||||
using instruction_ref = std::list<instruction>::iterator;
|
||||
#endif
|
||||
|
||||
MIGRAPHX_EXPORT migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<migraphx::instruction_ref> // NOLINT
|
||||
{
|
||||
using argument_type = migraphx::instruction_ref;
|
||||
using result_type = std::size_t;
|
||||
result_type operator()(const migraphx::instruction_ref& x) const noexcept
|
||||
{
|
||||
return std::hash<migraphx::instruction*>{}(migraphx::as_address(x));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equal_to<migraphx::instruction_ref> // NOLINT
|
||||
{
|
||||
using argument_type = migraphx::instruction_ref;
|
||||
using result_type = bool;
|
||||
result_type operator()(const migraphx::instruction_ref& x,
|
||||
const migraphx::instruction_ref& y) const noexcept
|
||||
{
|
||||
return migraphx::as_address(x) == migraphx::as_address(y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <migraphx/instruction.hpp>
|
||||
#endif
|
||||
|
||||
#endif
|
||||
164
docker/rocm/migraphx/include/migraphx/iota_iterator.hpp
Normal file
164
docker/rocm/migraphx/include/migraphx/iota_iterator.hpp
Normal file
@ -0,0 +1,164 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#ifndef MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
|
||||
#define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
|
||||
|
||||
#include <migraphx/config.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <iterator>
|
||||
#include <type_traits>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class F, class Iterator = std::ptrdiff_t>
|
||||
struct basic_iota_iterator
|
||||
{
|
||||
Iterator index;
|
||||
F f;
|
||||
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using reference = decltype(f(std::declval<Iterator>()));
|
||||
using value_type = typename std::remove_reference<reference>::type;
|
||||
using pointer = typename std::add_pointer<value_type>::type;
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
|
||||
basic_iota_iterator& operator+=(int n)
|
||||
{
|
||||
index += n;
|
||||
return *this;
|
||||
}
|
||||
|
||||
basic_iota_iterator& operator-=(int n)
|
||||
{
|
||||
index -= n;
|
||||
return *this;
|
||||
}
|
||||
|
||||
basic_iota_iterator& operator++()
|
||||
{
|
||||
index++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
basic_iota_iterator& operator--()
|
||||
{
|
||||
index--;
|
||||
return *this;
|
||||
}
|
||||
|
||||
basic_iota_iterator operator++(int) // NOLINT
|
||||
{
|
||||
basic_iota_iterator it = *this;
|
||||
index++;
|
||||
return it;
|
||||
}
|
||||
|
||||
basic_iota_iterator operator--(int) // NOLINT
|
||||
{
|
||||
basic_iota_iterator it = *this;
|
||||
index--;
|
||||
return it;
|
||||
}
|
||||
reference operator*() const { return f(index); }
|
||||
pointer operator->() const { return &f(index); }
|
||||
reference operator[](int n) const { return f(index + n); }
|
||||
};
|
||||
|
||||
template <class T, class F>
|
||||
inline basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
|
||||
{
|
||||
return basic_iota_iterator<F, T>{x, f};
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x,
|
||||
std::ptrdiff_t y)
|
||||
{
|
||||
return x += y;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline basic_iota_iterator<F, Iterator> operator+(std::ptrdiff_t x,
|
||||
basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
|
||||
basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index - y.index;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x,
|
||||
std::ptrdiff_t y)
|
||||
{
|
||||
return x -= y;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index == y.index;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index != y.index;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index < y.index;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index > y.index;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index >= y.index;
|
||||
}
|
||||
|
||||
template <class F, class Iterator>
|
||||
inline bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
|
||||
{
|
||||
return x.index <= y.index;
|
||||
}
|
||||
|
||||
using iota_iterator = basic_iota_iterator<id>;
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user