Formatting

This commit is contained in:
Nicolas Mowen 2025-08-20 14:38:34 -06:00
parent b1c8d494c9
commit ed8af79efa
2 changed files with 68 additions and 47 deletions

View File

@@ -97,9 +97,11 @@ class Rknn(DetectionApi):
model_props["preset"] = False
# Check if this is an ONNX model or model without extension that needs conversion
if model_path.endswith('.onnx') or not os.path.splitext(model_path)[1]:
if model_path.endswith(".onnx") or not os.path.splitext(model_path)[1]:
# Try to auto-convert to RKNN format
logger.info(f"Attempting to auto-convert {model_path} to RKNN format...")
logger.info(
f"Attempting to auto-convert {model_path} to RKNN format..."
)
# Determine model type from config
model_type = self.detector_config.model.model_type
@@ -112,7 +114,9 @@ class Rknn(DetectionApi):
logger.info(f"Successfully converted model to: {converted_path}")
else:
# Fall back to original path if conversion fails
logger.warning(f"Failed to convert {model_path} to RKNN format, using original path")
logger.warning(
f"Failed to convert {model_path} to RKNN format, using original path"
)
model_props["path"] = model_path
else:
model_props["path"] = model_path

View File

@@ -28,10 +28,12 @@ MODEL_TYPE_CONFIGS = {
},
}
def ensure_torch_dependencies() -> bool:
"""Dynamically install torch dependencies if not available."""
try:
import torch
logger.debug("PyTorch is already available")
return True
except ImportError:
@@ -39,30 +41,43 @@ def ensure_torch_dependencies() -> bool:
try:
# Try to install torch using pip
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"--break-system-packages", "torch", "torchvision"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"--break-system-packages",
"torch",
"torchvision",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
# Verify installation
import torch
logger.info("PyTorch installed successfully")
return True
except (subprocess.CalledProcessError, ImportError) as e:
logger.error(f"Failed to install PyTorch: {e}")
return False
def ensure_rknn_toolkit() -> bool:
"""Ensure RKNN toolkit is available."""
try:
import rknn
from rknn.api import RKNN
logger.debug("RKNN toolkit is already available")
return True
except ImportError:
logger.error("RKNN toolkit not found. Please ensure it's installed.")
return False
def get_soc_type() -> Optional[str]:
"""Get the SoC type from device tree."""
try:
@@ -73,12 +88,13 @@ def get_soc_type() -> Optional[str]:
logger.warning("Could not determine SoC type from device tree")
return None
def convert_onnx_to_rknn(
onnx_path: str,
output_path: str,
model_type: str,
quantization: bool = False,
soc: Optional[str] = None
soc: Optional[str] = None,
) -> bool:
"""
Convert ONNX model to RKNN format.
@@ -150,10 +166,9 @@ def convert_onnx_to_rknn(
logger.error(f"Error during RKNN conversion: {e}")
return False
def auto_convert_model(
model_path: str,
model_type: str,
quantization: bool = False
model_path: str, model_type: str, quantization: bool = False
) -> Optional[str]:
"""
Automatically convert a model to RKNN format if needed.
@@ -169,12 +184,12 @@ def auto_convert_model(
from frigate.const import MODEL_CACHE_DIR
# Check if model already has .rknn extension
if model_path.endswith('.rknn'):
if model_path.endswith(".rknn"):
return model_path
# Check if equivalent .rknn file exists
base_path = Path(model_path)
if base_path.suffix.lower() in ['.onnx', '']:
if base_path.suffix.lower() in [".onnx", ""]:
# Remove extension if present
base_name = base_path.stem if base_path.suffix else base_path.name
rknn_path = base_path.parent / f"{base_name}.rknn"
@@ -184,13 +199,15 @@ def auto_convert_model(
return str(rknn_path)
# Convert ONNX to RKNN
if base_path.suffix.lower() == '.onnx' or not base_path.suffix:
if base_path.suffix.lower() == ".onnx" or not base_path.suffix:
logger.info(f"Converting {model_path} to RKNN format...")
# Create output directory if it doesn't exist
rknn_path.parent.mkdir(parents=True, exist_ok=True)
if convert_onnx_to_rknn(str(base_path), str(rknn_path), model_type, quantization):
if convert_onnx_to_rknn(
str(base_path), str(rknn_path), model_type, quantization
):
return str(rknn_path)
else:
logger.error(f"Failed to convert {model_path} to RKNN format")