frigate/frigate/detectors/plugins/axengine.py
2025-11-24 02:42:04 +00:00

88 lines
2.8 KiB
Python

import logging
import os.path
import re
import urllib.request
from typing import Literal
import axengine as axe
from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
from frigate.util.model import post_process_yolo
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Identifier of this detector plugin; reused below for the config `type`
# literal and for the detector class's `type_key`.
DETECTOR_KEY = "axengine"

# Regular expressions describing the preset model names accepted for each
# model type.
supported_models = {ModelTypeEnum.yologeneric: "frigate-yolov9-.*$"}

# Directory holding downloaded preset models; the value ends with a path
# separator.
model_cache_dir = os.path.join(MODEL_CACHE_DIR, "axengine_cache/")
class AxengineDetectorConfig(BaseDetectorConfig):
    # Configuration model for the axengine detector. It adds nothing to the
    # base detector configuration except pinning `type` to DETECTOR_KEY.
    type: Literal[DETECTOR_KEY]
class Axengine(DetectionApi):
    """Detector that runs a model through an ``axengine`` inference session.

    Attributes:
        height: Model input height taken from ``config.model.height``.
        width: Model input width taken from ``config.model.width``.
        session: ``axe.InferenceSession`` created from the resolved model file.
    """

    type_key = DETECTOR_KEY

    def __init__(self, config: AxengineDetectorConfig):
        """Resolve the configured model and open an inference session for it.

        Args:
            config: Detector configuration; ``config.model`` supplies
                ``height``, ``width`` and ``path``.

        Raises:
            ValueError: If the configured model is neither a supported preset
                nor an existing file (see ``parse_model_input``).
        """
        logger.info("__init__ axengine")
        super().__init__(config)
        self.height = config.model.height
        self.width = config.model.width

        # Fall back to the default preset when no model path is configured.
        model_path = config.model.path or "frigate-yolov9-tiny"
        model_props = self.parse_model_input(model_path)
        self.session = axe.InferenceSession(model_props["path"])

    def __del__(self):
        # No explicit teardown is performed for the session.
        pass

    def parse_model_input(self, model_path):
        """Turn a configured model name or file into model properties.

        A name matching one of the ``supported_models`` patterns is treated as
        a preset: it is looked up in the model cache directory and downloaded
        first if it is not there yet. Anything else is accepted only if it is
        an existing file, which is then used as-is.

        Args:
            model_path: Preset model name, or path to an existing model file.

        Returns:
            A dict with ``"preset"`` (bool) and ``"path"`` (str). Presets
            additionally carry ``"model_type"`` and ``"filename"``.

        Raises:
            ValueError: If ``model_path`` is neither a supported preset nor an
                existing file.
        """
        model_props = {"preset": True}

        # No early exit on purpose: if several patterns match, the last one
        # determines the model type.
        for model_type, pattern in supported_models.items():
            if re.match(pattern, model_path):
                model_props["model_type"] = model_type

        if "model_type" in model_props:
            model_props["filename"] = model_path + ".axmodel"
            model_props["path"] = os.path.join(
                model_cache_dir, model_props["filename"]
            )
            if not os.path.isfile(model_props["path"]):
                self.download_model(model_props["filename"])
            return model_props

        # Not a preset: accept a model file supplied by the user, which is
        # what the error message below tells them to do.
        if os.path.isfile(model_path):
            model_props["preset"] = False
            model_props["path"] = model_path
            return model_props

        # List the preset patterns themselves (without regex anchors) so the
        # hint is readable.
        supported_models_str = ", ".join(
            pattern.strip("^$") for pattern in supported_models.values()
        )
        raise ValueError(
            f"Model {model_path} is unsupported. Provide your own model or choose one of the following: {supported_models_str}"
        )

    def download_model(self, filename):
        """Download a preset model file into the model cache directory.

        The file is first written next to its final location with a
        ``.part`` suffix and only moved into place once the download has
        finished, so an interrupted download cannot leave a truncated file
        under the final name.

        Args:
            filename: Name of the release asset; also used as the file name
                inside the cache directory.

        Raises:
            OSError: If the cache directory or file cannot be written, or the
                download fails (``urllib`` errors derive from ``OSError``).
        """
        # Creates missing parent directories and tolerates the directory
        # already existing.
        os.makedirs(model_cache_dir, exist_ok=True)

        # The download host can be overridden through the environment.
        GITHUB_ENDPOINT = os.environ.get("GITHUB_ENDPOINT", "https://github.com")
        url = f"{GITHUB_ENDPOINT}/ivanshi1108/assets/releases/download/v0.16.2/{filename}"

        target_path = os.path.join(model_cache_dir, filename)
        partial_path = target_path + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, target_path)
        finally:
            # After a successful move the partial file is gone; otherwise
            # remove whatever the failed download left behind.
            if os.path.isfile(partial_path):
                os.remove(partial_path)

    def detect_raw(self, tensor_input):
        """Run inference on one input tensor and post-process the result.

        Args:
            tensor_input: Input passed to the session under the name
                ``"images"`` (layout is whatever the loaded model expects;
                not visible from this class).

        Returns:
            The value returned by ``post_process_yolo`` for the raw outputs
            and the configured model width and height.

        Raises:
            ValueError: If the configured model type is not
                ``ModelTypeEnum.yologeneric``.
        """
        # NOTE(review): `detector_config` is not assigned in this class;
        # presumably it is set by DetectionApi.__init__ — confirm.
        model_type = self.detector_config.model.model_type

        # Reject unsupported model types before spending time on inference.
        if model_type != ModelTypeEnum.yologeneric:
            raise ValueError(
                f'Model type "{model_type}" is currently not supported.'
            )

        results = self.session.run(None, {"images": tensor_input})
        return post_process_yolo(results, self.width, self.height)