Assume rocm model is onnx despite file extension

This commit is contained in:
Nicolas Mowen 2024-10-29 10:53:44 -06:00
parent ee10f00758
commit 94758738b6

View File

@ -98,9 +98,7 @@ class ROCmDetector(DetectionApi):
else: else:
logger.info(f"AMD/ROCm: loading model from {path}") logger.info(f"AMD/ROCm: loading model from {path}")
if path.endswith(".onnx"): if (
self.model = migraphx.parse_onnx(path)
elif (
path.endswith(".tf") path.endswith(".tf")
or path.endswith(".tf2") or path.endswith(".tf2")
or path.endswith(".tflite") or path.endswith(".tflite")
@ -108,7 +106,7 @@ class ROCmDetector(DetectionApi):
# untested # untested
self.model = migraphx.parse_tf(path) self.model = migraphx.parse_tf(path)
else: else:
raise Exception(f"AMD/ROCm: unknown model format {path}") self.model = migraphx.parse_onnx(path)
logger.info("AMD/ROCm: compiling the model") logger.info("AMD/ROCm: compiling the model")