# Mirrored from https://github.com/blakeblackshear/frigate.git
# (mirror synced 2025-12-18 19:16:42 +03:00; original file: 88 lines, 2.8 KiB, Python)
import logging
|
|
import os.path
|
|
import re
|
|
import urllib.request
|
|
from typing import Literal
|
|
|
|
import axengine as axe
|
|
|
|
from frigate.const import MODEL_CACHE_DIR
|
|
from frigate.detectors.detection_api import DetectionApi
|
|
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
|
|
from frigate.util.model import post_process_yolo
|
|
|
|
logger = logging.getLogger(__name__)

# Identifier of this detector; used both as the config ``type`` literal and
# as the detector class's ``type_key``.
DETECTOR_KEY = "axengine"

# Preset model names this detector can fetch automatically, mapped from the
# model type they produce to a regex that the requested name must match.
supported_models = {ModelTypeEnum.yologeneric: "frigate-yolov9-.*$"}

# Local directory where downloaded preset models are stored (note the
# trailing separator in the joined path).
model_cache_dir = os.path.join(MODEL_CACHE_DIR, "axengine_cache/")


class AxengineDetectorConfig(BaseDetectorConfig):
    """Detector config schema for the axengine backend.

    Only adds a ``type`` discriminator that must equal ``DETECTOR_KEY``.
    """

    # Literal discriminator: restricts ``type`` to exactly DETECTOR_KEY.
    type: Literal[DETECTOR_KEY]


class Axengine(DetectionApi):
    """Detector that runs models through ``axengine.InferenceSession``."""

    type_key = DETECTOR_KEY

    def __init__(self, config: AxengineDetectorConfig):
        """Resolve the configured model (downloading presets) and open a session.

        Args:
            config: Detector configuration; ``config.model`` provides the
                input ``height``/``width`` and an optional model ``path``.
        """
        logger.info("__init__ axengine")
        super().__init__(config)
        self.height = config.model.height
        self.width = config.model.width
        # No explicit path configured: fall back to the default preset.
        model_path = config.model.path or "frigate-yolov9-tiny"
        model_props = self.parse_model_input(model_path)
        self.session = axe.InferenceSession(model_props["path"])

    def __del__(self):
        # No-op finalizer: nothing is released explicitly here.
        pass

    def parse_model_input(self, model_path):
        """Resolve *model_path* to a local ``.axmodel`` file.

        *model_path* is either a path to a user-provided ``.axmodel`` file
        (used as-is) or the name of a preset matching one of the patterns in
        ``supported_models`` (downloaded into ``model_cache_dir`` on first
        use).

        Args:
            model_path: custom ``.axmodel`` path or preset model name.

        Returns:
            dict with ``preset`` (bool) and ``path`` (str). Presets also
            carry ``model_type`` and ``filename``.

        Raises:
            Exception: if *model_path* is neither a ``.axmodel`` path nor a
                supported preset name.
        """
        model_props = {}

        # User-supplied model: the error message below explicitly invites
        # users to provide their own model, so accept explicit .axmodel paths.
        if model_path.endswith(".axmodel"):
            model_props["preset"] = False
            model_props["path"] = model_path
            return model_props

        model_props["preset"] = True

        for model_type, pattern in supported_models.items():
            if re.match(pattern, model_path):
                model_props["model_type"] = model_type
                break
        else:
            # List the preset patterns (not the enum keys), minus regex anchors.
            supported_models_str = ", ".join(
                pattern.lstrip("^").rstrip("$")
                for pattern in supported_models.values()
            )
            raise Exception(
                f"Model {model_path} is unsupported. Provide your own model or choose one of the following: {supported_models_str}"
            )

        model_props["filename"] = model_path + ".axmodel"
        model_props["path"] = os.path.join(model_cache_dir, model_props["filename"])

        if not os.path.isfile(model_props["path"]):
            self.download_model(model_props["filename"])

        return model_props

    def download_model(self, filename):
        """Download preset model *filename* into ``model_cache_dir``.

        The download goes to a temporary file that is moved into place only
        after it completes. An interrupted download therefore never leaves a
        truncated file behind that the ``os.path.isfile`` cache check in
        ``parse_model_input`` would accept as a valid model.

        The release host can be overridden with the ``GITHUB_ENDPOINT``
        environment variable.
        """
        # makedirs also creates missing parents and tolerates the directory
        # appearing concurrently, unlike an isdir() check followed by mkdir().
        os.makedirs(model_cache_dir, exist_ok=True)

        github_endpoint = os.environ.get("GITHUB_ENDPOINT", "https://github.com")
        url = f"{github_endpoint}/ivanshi1108/assets/releases/download/v0.16.2/{filename}"
        destination = os.path.join(model_cache_dir, filename)
        # Per-process temp name so concurrent downloaders don't clobber each other.
        temp_path = f"{destination}.{os.getpid()}.part"

        logger.info("Downloading axengine model %s", filename)
        try:
            urllib.request.urlretrieve(url, temp_path)
            # Atomic on the same filesystem: readers see either no file or
            # the complete file.
            os.replace(temp_path, destination)
        finally:
            # After a successful replace the temp name is gone; otherwise
            # drop the partial download without masking the original error.
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    logger.warning("Could not remove partial download %s", temp_path)

    def detect_raw(self, tensor_input):
        """Run inference on *tensor_input* and post-process the outputs.

        The tensor is fed to the session as the input named ``images``.
        Only ``ModelTypeEnum.yologeneric`` models are supported; their
        outputs are decoded with ``post_process_yolo`` using the configured
        input width/height.

        Raises:
            ValueError: if the configured model type is not supported.
        """
        model_type = self.detector_config.model.model_type
        # Validate before running inference so no work is wasted on a model
        # whose outputs cannot be decoded.
        if model_type != ModelTypeEnum.yologeneric:
            raise ValueError(
                f'Model type "{model_type}" is currently not supported.'
            )

        results = self.session.run(None, {"images": tensor_input})
        return post_process_yolo(results, self.width, self.height)