Formatting

This commit is contained in:
Nicolas Mowen 2025-08-20 17:09:59 -06:00
parent e294246afe
commit ba77f066a0
3 changed files with 15 additions and 10 deletions

View File

@ -9,7 +9,7 @@ build-rk: version
docker buildx bake --file=docker/rockchip/rk.hcl rk \ docker buildx bake --file=docker/rockchip/rk.hcl rk \
--set rk.tags=$(IMAGE_REPO):${GITHUB_REF_NAME}-$(COMMIT_HASH)-rk --set rk.tags=$(IMAGE_REPO):${GITHUB_REF_NAME}-$(COMMIT_HASH)-rk
push-rk: build-rk push-rk: version
docker buildx bake --file=docker/rockchip/rk.hcl rk \ docker buildx bake --file=docker/rockchip/rk.hcl rk \
--set rk.tags=$(IMAGE_REPO):${GITHUB_REF_NAME}-$(COMMIT_HASH)-rk \ --set rk.tags=crzynik/frigate:rk \
--push --push

View File

@ -181,7 +181,7 @@ class RKNNModelRunner:
"""Get input names for the model.""" """Get input names for the model."""
# For CLIP models, we need to determine the model type from the path # For CLIP models, we need to determine the model type from the path
model_name = os.path.basename(self.model_path).lower() model_name = os.path.basename(self.model_path).lower()
if "vision" in model_name: if "vision" in model_name:
return ["pixel_values"] return ["pixel_values"]
else: else:
@ -189,7 +189,7 @@ class RKNNModelRunner:
if self.model_type and "jina-clip" in self.model_type: if self.model_type and "jina-clip" in self.model_type:
if "vision" in self.model_type: if "vision" in self.model_type:
return ["pixel_values"] return ["pixel_values"]
# Generic fallback # Generic fallback
return ["input"] return ["input"]
@ -209,7 +209,7 @@ class RKNNModelRunner:
try: try:
input_names = self.get_input_names() input_names = self.get_input_names()
rknn_inputs = [] rknn_inputs = []
for name in input_names: for name in input_names:
if name in inputs: if name in inputs:
if name == "pixel_values": if name == "pixel_values":
@ -224,21 +224,23 @@ class RKNNModelRunner:
rknn_inputs.append(inputs[name]) rknn_inputs.append(inputs[name])
else: else:
logger.warning(f"Input '{name}' not found in inputs, using default") logger.warning(f"Input '{name}' not found in inputs, using default")
if name == "pixel_values": if name == "pixel_values":
batch_size = 1 batch_size = 1
if inputs: if inputs:
for val in inputs.values(): for val in inputs.values():
if hasattr(val, 'shape') and len(val.shape) > 0: if hasattr(val, "shape") and len(val.shape) > 0:
batch_size = val.shape[0] batch_size = val.shape[0]
break break
# Create default in NHWC format as expected by RKNN # Create default in NHWC format as expected by RKNN
rknn_inputs.append(np.zeros((batch_size, 224, 224, 3), dtype=np.float32)) rknn_inputs.append(
np.zeros((batch_size, 224, 224, 3), dtype=np.float32)
)
else: else:
batch_size = 1 batch_size = 1
if inputs: if inputs:
for val in inputs.values(): for val in inputs.values():
if hasattr(val, 'shape') and len(val.shape) > 0: if hasattr(val, "shape") and len(val.shape) > 0:
batch_size = val.shape[0] batch_size = val.shape[0]
break break
rknn_inputs.append(np.zeros((batch_size, 1), dtype=np.float32)) rknn_inputs.append(np.zeros((batch_size, 1), dtype=np.float32))

View File

@ -38,6 +38,7 @@ MODEL_TYPE_CONFIGS = {
}, },
} }
def get_rknn_model_type(model_path: str) -> str | None: def get_rknn_model_type(model_path: str) -> str | None:
if all(keyword in model_path for keyword in ["jina-clip-v1", "vision"]): if all(keyword in model_path for keyword in ["jina-clip-v1", "vision"]):
return "jina-clip-v1-vision" return "jina-clip-v1-vision"
@ -49,6 +50,7 @@ def get_rknn_model_type(model_path: str) -> str | None:
return None return None
def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool: def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
""" """
Check if a model is compatible with RKNN conversion. Check if a model is compatible with RKNN conversion.
@ -111,6 +113,7 @@ def ensure_rknn_toolkit() -> bool:
"""Ensure RKNN toolkit is available.""" """Ensure RKNN toolkit is available."""
try: try:
from rknn.api import RKNN # type: ignore # noqa: F401 from rknn.api import RKNN # type: ignore # noqa: F401
logger.debug("RKNN toolkit is already available") logger.debug("RKNN toolkit is already available")
return True return True
except ImportError as e: except ImportError as e:
@ -438,7 +441,7 @@ def auto_convert_model(
if not model_type: if not model_type:
model_type = get_rknn_model_type(base_path) model_type = get_rknn_model_type(base_path)
if wait_for_conversion_completion(model_type, rknn_path, lock_file_path): if wait_for_conversion_completion(model_type, rknn_path, lock_file_path):
return str(rknn_path) return str(rknn_path)
else: else: