Enable model caching

This commit is contained in:
Nicolas Mowen 2025-09-27 07:20:33 -06:00
parent aaeab73505
commit cacdf6bb84
2 changed files with 7 additions and 21 deletions

View File

@ -2,7 +2,7 @@ variable "AMDGPU" {
default = "gfx900"
}
variable "ROCM" {
default = "7.0.0"
default = "7.0.1"
}
variable "HSA_OVERRIDE_GFX_VERSION" {
default = ""

View File

@ -354,29 +354,15 @@ def get_ort_providers(
}
)
elif provider == "MIGraphXExecutionProvider":
# Create MIGraphX cache directory
migraphx_cache_dir = os.path.join(MODEL_CACHE_DIR, "migraphx")
os.makedirs(migraphx_cache_dir, exist_ok=True)
if model_path:
model_filename = os.path.basename(model_path)
model_name = os.path.splitext(model_filename)[0] # Remove extension
compiled_model_path = os.path.join(
migraphx_cache_dir, f"{model_name}.mxr"
)
if os.path.exists(compiled_model_path):
os.environ["ORT_MIGRAPHX_LOAD_COMPILED_MODEL"] = "1"
os.environ["ORT_MIGRAPHX_LOAD_COMPILED_PATH"] = compiled_model_path
else:
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_MODEL"] = "1"
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_PATH"] = compiled_model_path
providers.append(provider)
options.append({})
else:
providers.append(provider)
options.append({})
providers.append(provider)
options.append(
{
"migraphx_model_cache_dir": migraphx_cache_dir,
}
)
elif provider == "CPUExecutionProvider":
providers.append(provider)
options.append(