Handle model type inference

This commit is contained in:
Nicolas Mowen 2025-08-20 16:40:28 -06:00
parent 4ac77d1257
commit 2283b32575
2 changed files with 18 additions and 12 deletions

View File

@ -30,7 +30,7 @@ class ONNXModelRunner:
self.type = "ort"
try:
if is_rknn_compatible(model_path):
if device != "CPU" and is_rknn_compatible(model_path):
# Try to auto-convert to RKNN format
rknn_path = auto_convert_model(model_path)
if rknn_path:

View File

@ -38,6 +38,16 @@ MODEL_TYPE_CONFIGS = {
},
}
def get_rknn_model_type(model_path: str) -> str | None:
if all(keyword in model_path for keyword in ["jina-clip-v1", "vision"]):
return "jina-clip-v1-vision"
model_name = os.path.basename(model_path).lower()
if any(keyword in model_name for keyword in ["yolo", "yolox", "yolonas"]):
return model_name
return None
def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
"""
@ -57,17 +67,12 @@ def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
if soc is None:
return False
if not model_type:
model_type = get_rknn_model_type(model_path)
if model_type and model_type in MODEL_TYPE_CONFIGS:
return True
model_name = os.path.basename(model_path).lower()
if any(keyword in model_name for keyword in ["jina", "clip", "vision"]):
return True
if any(keyword in model_name for keyword in ["yolo", "yolox", "yolonas"]):
return True
return False
@ -108,9 +113,7 @@ def ensure_torch_dependencies() -> bool:
def ensure_rknn_toolkit() -> bool:
"""Ensure RKNN toolkit is available."""
try:
import rknn # type: ignore # noqa: F401
from rknn.api import RKNN # type: ignore # noqa: F401
logger.debug("RKNN toolkit is already available")
return True
except ImportError:
@ -379,7 +382,7 @@ def wait_for_conversion_completion(
def auto_convert_model(
model_path: str, model_type: str, quantization: bool = False
model_path: str, model_type: str | None = None, quantization: bool = False
) -> Optional[str]:
"""
Automatically convert a model to RKNN format if needed.
@ -418,6 +421,9 @@ def auto_convert_model(
logger.info(f"Converting {model_path} to RKNN format...")
rknn_path.parent.mkdir(parents=True, exist_ok=True)
if not model_type:
model_type = get_rknn_model_type(base_path)
if convert_onnx_to_rknn(
str(base_path), str(rknn_path), model_type, quantization
):