"Update parse_model_input function to include config information; add a condition to check for specified model_type when using custom model path."

This commit is contained in:
eugene_yao 2024-06-17 14:27:21 -04:00
parent 5b60785cca
commit de3415f08d

View File

@ -36,7 +36,7 @@ class Rknn(DetectionApi):
core_mask = 2**config.num_cores - 1
soc = self.get_soc()
model_props = self.parse_model_input(config.model.path, soc)
model_props = self.parse_model_input(config, soc)
if model_props["preset"]:
config.model.model_type = model_props["model_type"]
@ -75,7 +75,9 @@ class Rknn(DetectionApi):
return soc
def parse_model_input(self, model_path, soc):
def parse_model_input(self, config, soc):
model_path = config.model.path
model_props = {}
# find out if user provides his own model
@ -83,6 +85,12 @@ class Rknn(DetectionApi):
if "/" in model_path:
model_props["preset"] = False
model_props["path"] = model_path
if config.model.model_type:
model_props["model_type"] = config.model.model_type
else:
raise Exception(
"You must specify model_type if specifying your own model file."
)
else:
model_props["preset"] = True