Formatting

This commit is contained in:
Nicolas Mowen 2025-08-20 14:38:34 -06:00
parent b1c8d494c9
commit ed8af79efa
2 changed files with 68 additions and 47 deletions

View File

@ -95,24 +95,28 @@ class Rknn(DetectionApi):
# user provided models should be a path and contain a "/" # user provided models should be a path and contain a "/"
if "/" in model_path: if "/" in model_path:
model_props["preset"] = False model_props["preset"] = False
# Check if this is an ONNX model or model without extension that needs conversion # Check if this is an ONNX model or model without extension that needs conversion
if model_path.endswith('.onnx') or not os.path.splitext(model_path)[1]: if model_path.endswith(".onnx") or not os.path.splitext(model_path)[1]:
# Try to auto-convert to RKNN format # Try to auto-convert to RKNN format
logger.info(f"Attempting to auto-convert {model_path} to RKNN format...") logger.info(
f"Attempting to auto-convert {model_path} to RKNN format..."
)
# Determine model type from config # Determine model type from config
model_type = self.detector_config.model.model_type model_type = self.detector_config.model.model_type
# Auto-convert the model # Auto-convert the model
converted_path = auto_convert_model(model_path, model_type.value) converted_path = auto_convert_model(model_path, model_type.value)
if converted_path: if converted_path:
model_props["path"] = converted_path model_props["path"] = converted_path
logger.info(f"Successfully converted model to: {converted_path}") logger.info(f"Successfully converted model to: {converted_path}")
else: else:
# Fall back to original path if conversion fails # Fall back to original path if conversion fails
logger.warning(f"Failed to convert {model_path} to RKNN format, using original path") logger.warning(
f"Failed to convert {model_path} to RKNN format, using original path"
)
model_props["path"] = model_path model_props["path"] = model_path
else: else:
model_props["path"] = model_path model_props["path"] = model_path

View File

@ -28,41 +28,56 @@ MODEL_TYPE_CONFIGS = {
}, },
} }
def ensure_torch_dependencies() -> bool:
    """Dynamically install torch dependencies if not available.

    Returns:
        True if PyTorch is importable (already present or freshly
        installed), False if installation failed.
    """
    try:
        import torch  # noqa: F401

        logger.debug("PyTorch is already available")
        return True
    except ImportError:
        logger.info("PyTorch not found, attempting to install...")
        try:
            # Install at runtime; pip output is suppressed so progress
            # spam does not flood the service logs.
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--break-system-packages",
                    "torch",
                    "torchvision",
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )

            # Fix: the import system caches finder results, so a package
            # installed mid-process may not be visible to a plain import
            # until the caches are invalidated.
            import importlib

            importlib.invalidate_caches()

            # Verify installation
            import torch  # noqa: F401

            logger.info("PyTorch installed successfully")
            return True
        except (subprocess.CalledProcessError, ImportError) as e:
            logger.error(f"Failed to install PyTorch: {e}")
            return False
def ensure_rknn_toolkit() -> bool:
    """Ensure the RKNN toolkit is available.

    Returns:
        True if ``rknn.api.RKNN`` can be imported, False otherwise.
    """
    try:
        # Importing the class (not just the package) verifies the install
        # is complete enough to actually be used; the redundant
        # `import rknn` was removed — this from-import already imports
        # the `rknn` package and raises ImportError on the same failures.
        from rknn.api import RKNN  # noqa: F401

        logger.debug("RKNN toolkit is already available")
        return True
    except ImportError:
        logger.error("RKNN toolkit not found. Please ensure it's installed.")
        return False
def get_soc_type() -> Optional[str]: def get_soc_type() -> Optional[str]:
"""Get the SoC type from device tree.""" """Get the SoC type from device tree."""
try: try:
@ -73,23 +88,24 @@ def get_soc_type() -> Optional[str]:
logger.warning("Could not determine SoC type from device tree") logger.warning("Could not determine SoC type from device tree")
return None return None
def convert_onnx_to_rknn(
    onnx_path: str,
    output_path: str,
    model_type: str,
    quantization: bool = False,
    soc: Optional[str] = None,
) -> bool:
    """
    Convert ONNX model to RKNN format.

    Args:
        onnx_path: Path to input ONNX model
        output_path: Path for output RKNN model
        model_type: Type of model (yolo-generic, yolonas, yolox, ssd)
        quantization: Whether to use 8-bit quantization (i8) or 16-bit float (fp16)
        soc: Target SoC platform (auto-detected if None)

    Returns:
        True if conversion successful, False otherwise
    """
    if not ensure_torch_dependencies():
        logger.error("PyTorch dependencies not available")
        return False

    if not ensure_rknn_toolkit():
        logger.error("RKNN toolkit not available")
        return False

    # Get SoC type if not provided
    if soc is None:
        soc = get_soc_type()
        if soc is None:
            logger.error("Could not determine SoC type")
            return False

    # Get model config for the specified type
    if model_type not in MODEL_TYPE_CONFIGS:
        logger.error(f"Unsupported model type: {model_type}")
        return False

    # Copy so the shared per-type template is not mutated when the
    # target platform is injected.
    config = MODEL_TYPE_CONFIGS[model_type].copy()
    config["target_platform"] = soc

    # Renamed from `rknn` to avoid shadowing the `rknn` module name.
    converter = None
    try:
        from rknn.api import RKNN

        logger.info(f"Converting {onnx_path} to RKNN format for {soc}")

        # Initialize RKNN
        converter = RKNN(verbose=True)

        # Configure RKNN
        converter.config(**config)

        # Load ONNX model
        if converter.load_onnx(model=onnx_path) != 0:
            logger.error("Failed to load ONNX model")
            return False

        # Build RKNN model
        if converter.build(do_quantization=quantization) != 0:
            logger.error("Failed to build RKNN model")
            return False

        # Export RKNN model
        if converter.export_rknn(output_path) != 0:
            logger.error("Failed to export RKNN model")
            return False

        logger.info(f"Successfully converted model to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Error during RKNN conversion: {e}")
        return False
    finally:
        # Fix: always free the native RKNN context; the original leaked it
        # on every early return and on exceptions.
        if converter is not None:
            converter.release()
def auto_convert_model(
    model_path: str, model_type: str, quantization: bool = False
) -> Optional[str]:
    """
    Automatically convert a model to RKNN format if needed.

    Args:
        model_path: Path to the model file
        model_type: Type of the model
        quantization: Whether to use quantization

    Returns:
        Path to the RKNN model if successful, None otherwise
    """
    # NOTE: the unused `from frigate.const import MODEL_CACHE_DIR` was
    # removed (F401); nothing in this function referenced it.
    base_path = Path(model_path)

    # Already an RKNN model — nothing to do. Fix: match the extension
    # case-insensitively, consistent with the other suffix checks.
    if base_path.suffix.lower() == ".rknn":
        return model_path

    # Only ONNX models or extension-less paths are convertible.
    if base_path.suffix.lower() not in (".onnx", ""):
        return None

    # Prefer a pre-converted .rknn file sitting next to the source model.
    base_name = base_path.stem if base_path.suffix else base_path.name
    rknn_path = base_path.parent / f"{base_name}.rknn"

    if rknn_path.exists():
        logger.info(f"Found existing RKNN model: {rknn_path}")
        return str(rknn_path)

    # Convert ONNX to RKNN (the original re-tested the suffix here; that
    # duplicate of the condition above was dropped).
    logger.info(f"Converting {model_path} to RKNN format...")

    # Create output directory if it doesn't exist
    rknn_path.parent.mkdir(parents=True, exist_ok=True)

    if convert_onnx_to_rknn(str(base_path), str(rknn_path), model_type, quantization):
        return str(rknn_path)

    logger.error(f"Failed to convert {model_path} to RKNN format")
    return None