Enable model caching

This commit is contained in:
Nicolas Mowen 2025-09-27 07:20:33 -06:00
parent aaeab73505
commit cacdf6bb84
2 changed files with 7 additions and 21 deletions

View File

@ -2,7 +2,7 @@ variable "AMDGPU" {
default = "gfx900" default = "gfx900"
} }
variable "ROCM" { variable "ROCM" {
default = "7.0.0" default = "7.0.1"
} }
variable "HSA_OVERRIDE_GFX_VERSION" { variable "HSA_OVERRIDE_GFX_VERSION" {
default = "" default = ""

View File

@ -354,29 +354,15 @@ def get_ort_providers(
} }
) )
elif provider == "MIGraphXExecutionProvider": elif provider == "MIGraphXExecutionProvider":
# Create MIGraphX cache directory
migraphx_cache_dir = os.path.join(MODEL_CACHE_DIR, "migraphx") migraphx_cache_dir = os.path.join(MODEL_CACHE_DIR, "migraphx")
os.makedirs(migraphx_cache_dir, exist_ok=True) os.makedirs(migraphx_cache_dir, exist_ok=True)
if model_path: providers.append(provider)
model_filename = os.path.basename(model_path) options.append(
model_name = os.path.splitext(model_filename)[0] # Remove extension {
compiled_model_path = os.path.join( "migraphx_model_cache_dir": migraphx_cache_dir,
migraphx_cache_dir, f"{model_name}.mxr" }
) )
if os.path.exists(compiled_model_path):
os.environ["ORT_MIGRAPHX_LOAD_COMPILED_MODEL"] = "1"
os.environ["ORT_MIGRAPHX_LOAD_COMPILED_PATH"] = compiled_model_path
else:
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_MODEL"] = "1"
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_PATH"] = compiled_model_path
providers.append(provider)
options.append({})
else:
providers.append(provider)
options.append({})
elif provider == "CPUExecutionProvider": elif provider == "CPUExecutionProvider":
providers.append(provider) providers.append(provider)
options.append( options.append(