Use fp16 if required

Nicolas Mowen 2024-10-09 16:04:20 -06:00
parent 24a1b16bc0
commit ab8a2caba2
2 changed files with 17 additions and 10 deletions

View File

@@ -50,7 +50,9 @@ class GenericONNXEmbedding:
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
-        self.providers, self.provider_options = get_ort_providers(force_cpu=force_cpu)
+        self.providers, self.provider_options = get_ort_providers(
+            force_cpu=force_cpu, requires_fp16=True
+        )
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
         self.tokenizer = None
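
For context, a minimal usage sketch (not part of this commit) of how the providers and provider options returned by get_ort_providers are typically handed to an ONNX Runtime session; the model path below is a placeholder and assumes onnxruntime is installed:

import onnxruntime as ort

# Hypothetical usage sketch; "model.onnx" is a placeholder path, not from this diff.
providers, provider_options = get_ort_providers(force_cpu=False, requires_fp16=True)
session = ort.InferenceSession(
    "model.onnx",
    providers=providers,
    provider_options=provider_options,
)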

View File

@@ -6,7 +6,7 @@ import onnxruntime as ort
 def get_ort_providers(
-    force_cpu: bool = False, openvino_device: str = "AUTO"
+    force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
 ) -> tuple[list[str], list[dict[str, any]]]:
     if force_cpu:
         return (["CPUExecutionProvider"], [{}])
@@ -17,14 +17,19 @@ def get_ort_providers(
     for provider in providers:
         if provider == "TensorrtExecutionProvider":
             os.makedirs("/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True)
+            if not requires_fp16 or os.environ.get("USE_FP_16", "True") != "False":
+                options.append(
+                    {
+                        "trt_fp16_enable": requires_fp16,
+                        "trt_timing_cache_enable": True,
+                        "trt_engine_cache_enable": True,
+                        "trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
+                        "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
+                    }
+                )
+            else:
+                options.append({})
         elif provider == "OpenVINOExecutionProvider":
             os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
             options.append(
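
A hedged reading of the new gating condition above: when a caller passes requires_fp16=True and the USE_FP_16 environment variable is not set to "False", the TensorRT provider is configured with trt_fp16_enable=True plus the engine/timing cache paths; setting USE_FP_16=False falls through to the empty-options branch. The standalone snippet below just mirrors that condition to show the effect (names taken from the diff):

import os

# Mirror of the gating condition in this diff.
requires_fp16 = True
os.environ["USE_FP_16"] = "False"
use_fp16_options = not requires_fp16 or os.environ.get("USE_FP_16", "True") != "False"
print(use_fp16_options)  # False -> the TensorrtExecutionProvider entry would be {}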