mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-25 16:17:41 +03:00
Handle model type inference
This commit is contained in:
parent
4ac77d1257
commit
2283b32575
@ -30,7 +30,7 @@ class ONNXModelRunner:
|
|||||||
self.type = "ort"
|
self.type = "ort"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_rknn_compatible(model_path):
|
if device != "CPU" and is_rknn_compatible(model_path):
|
||||||
# Try to auto-convert to RKNN format
|
# Try to auto-convert to RKNN format
|
||||||
rknn_path = auto_convert_model(model_path)
|
rknn_path = auto_convert_model(model_path)
|
||||||
if rknn_path:
|
if rknn_path:
|
||||||
|
|||||||
@ -38,6 +38,16 @@ MODEL_TYPE_CONFIGS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_rknn_model_type(model_path: str) -> str | None:
|
||||||
|
if all(keyword in model_path for keyword in ["jina-clip-v1", "vision"]):
|
||||||
|
return "jina-clip-v1-vision"
|
||||||
|
|
||||||
|
model_name = os.path.basename(model_path).lower()
|
||||||
|
|
||||||
|
if any(keyword in model_name for keyword in ["yolo", "yolox", "yolonas"]):
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
|
def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -57,17 +67,12 @@ def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
|
|||||||
if soc is None:
|
if soc is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if not model_type:
|
||||||
|
model_type = get_rknn_model_type(model_path)
|
||||||
|
|
||||||
if model_type and model_type in MODEL_TYPE_CONFIGS:
|
if model_type and model_type in MODEL_TYPE_CONFIGS:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
model_name = os.path.basename(model_path).lower()
|
|
||||||
|
|
||||||
if any(keyword in model_name for keyword in ["jina", "clip", "vision"]):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if any(keyword in model_name for keyword in ["yolo", "yolox", "yolonas"]):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -108,9 +113,7 @@ def ensure_torch_dependencies() -> bool:
|
|||||||
def ensure_rknn_toolkit() -> bool:
|
def ensure_rknn_toolkit() -> bool:
|
||||||
"""Ensure RKNN toolkit is available."""
|
"""Ensure RKNN toolkit is available."""
|
||||||
try:
|
try:
|
||||||
import rknn # type: ignore # noqa: F401
|
|
||||||
from rknn.api import RKNN # type: ignore # noqa: F401
|
from rknn.api import RKNN # type: ignore # noqa: F401
|
||||||
|
|
||||||
logger.debug("RKNN toolkit is already available")
|
logger.debug("RKNN toolkit is already available")
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -379,7 +382,7 @@ def wait_for_conversion_completion(
|
|||||||
|
|
||||||
|
|
||||||
def auto_convert_model(
|
def auto_convert_model(
|
||||||
model_path: str, model_type: str, quantization: bool = False
|
model_path: str, model_type: str | None = None, quantization: bool = False
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Automatically convert a model to RKNN format if needed.
|
Automatically convert a model to RKNN format if needed.
|
||||||
@ -418,6 +421,9 @@ def auto_convert_model(
|
|||||||
logger.info(f"Converting {model_path} to RKNN format...")
|
logger.info(f"Converting {model_path} to RKNN format...")
|
||||||
rknn_path.parent.mkdir(parents=True, exist_ok=True)
|
rknn_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if not model_type:
|
||||||
|
model_type = get_rknn_model_type(base_path)
|
||||||
|
|
||||||
if convert_onnx_to_rknn(
|
if convert_onnx_to_rknn(
|
||||||
str(base_path), str(rknn_path), model_type, quantization
|
str(base_path), str(rknn_path), model_type, quantization
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user