more robust key check and warning message

This commit is contained in:
Josh Hawkins 2024-10-13 16:14:24 -05:00
parent 5374c18e71
commit afc668fe0a

View File

@ -190,9 +190,11 @@ class GenericONNXEmbedding:
if key in input_names:
onnx_inputs[key].append(value[0])
for key in onnx_inputs.keys():
if onnx_inputs[key]:
for key in input_names:
if onnx_inputs.get(key):
onnx_inputs[key] = np.stack(onnx_inputs[key])
else:
logger.warning(f"Expected input '{key}' not found in onnx_inputs")
embeddings = self.runner.run(onnx_inputs)[0]
return [embedding for embedding in embeddings]