apply code formatting

This commit is contained in:
MarcA711 2023-11-15 13:48:02 +00:00
parent 61ccf7fdf7
commit 69674cc3e0

View File

@ -27,10 +27,8 @@ DETECTOR_KEY = "rknn"
class RknnDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY]
yolov8_rknn_model: Literal['n', 's', 'm', 'l', 'x'] = 'n'
core_mask: int = Field(
default=0, ge=0, le=7, title="Core mask for NPU."
)
yolov8_rknn_model: Literal["n", "s", "m", "l", "x"] = "n"
core_mask: int = Field(default=0, ge=0, le=7, title="Core mask for NPU.")
min_score: float = Field(
default=0.5, ge=0, le=1, title="Minimal confidence for detection."
)
@ -43,37 +41,60 @@ class Rknn(DetectionApi):
type_key = DETECTOR_KEY
def __init__(self, config: RknnDetectorConfig):
self.model_path = config.model.path or "/models/yolov8n-320x320.rknn"
if config.model.path != None:
self.model_path = config.model.path
else:
if config.yolov8_rknn_model == 'n':
if config.yolov8_rknn_model == "n":
self.model_path = "/models/yolov8n-320x320.rknn"
else:
# check if user mounted /models/download/
if not os.path.isdir("/models/download/"):
logger.error("Make sure to mount the directory \"/models/download/\" to your system. Otherwise the file will be downloaded at every restart.")
raise Exception("Make sure to mount the directory \"/models/download/\" to your system. Otherwise the file will be downloaded at every restart.")
logger.error(
'Make sure to mount the directory "/models/download/" to your system. Otherwise the file will be downloaded at every restart.'
)
raise Exception(
'Make sure to mount the directory "/models/download/" to your system. Otherwise the file will be downloaded at every restart.'
)
self.model_path = "/models/download/yolov8{}-320x320.rknn".format(config.yolov8_rknn_model)
self.model_path = "/models/download/yolov8{}-320x320.rknn".format(
config.yolov8_rknn_model
)
if os.path.isfile(self.model_path) == False:
logger.info("Downloading yolov8{} model.".format(config.yolov8_rknn_model))
urllib.request.urlretrieve("https://github.com/MarcA711/rknn-models/releases/download/latest/yolov8{}-320x320.rknn".format(config.yolov8_rknn_model), self.model_path)
logger.info(
"Downloading yolov8{} model.".format(config.yolov8_rknn_model)
)
urllib.request.urlretrieve(
"https://github.com/MarcA711/rknn-models/releases/download/latest/yolov8{}-320x320.rknn".format(
config.yolov8_rknn_model
),
self.model_path,
)
if (config.model.width != 320) or (config.model.height != 320):
logger.error("Make sure to set the model width and heigth to 320 in your config.yml.")
raise Exception("Make sure to set the model width and heigth to 320 in your config.yml.")
logger.error(
"Make sure to set the model width and heigth to 320 in your config.yml."
)
raise Exception(
"Make sure to set the model width and heigth to 320 in your config.yml."
)
if config.model.input_pixel_format != 'bgr':
logger.error("Make sure to set the model input_pixel_format to \"bgr\" in your config.yml.")
raise Exception("Make sure to set the model input_pixel_format to \"bgr\" in your config.yml.")
if config.model.input_pixel_format != "bgr":
logger.error(
'Make sure to set the model input_pixel_format to "bgr" in your config.yml.'
)
raise Exception(
'Make sure to set the model input_pixel_format to "bgr" in your config.yml.'
)
if config.model.input_tensor != 'nhwc':
logger.error("Make sure to set the model input_tensor to \"nhwc\" in your config.yml.")
raise Exception("Make sure to set the model input_tensor to \"nhwc\" in your config.yml.")
if config.model.input_tensor != "nhwc":
logger.error(
'Make sure to set the model input_tensor to "nhwc" in your config.yml.'
)
raise Exception(
'Make sure to set the model input_tensor to "nhwc" in your config.yml.'
)
self.height = config.model.height
self.width = config.model.width
@ -87,7 +108,9 @@ class Rknn(DetectionApi):
if self.rknn.load_rknn(self.model_path) != 0:
logger.error("Error initializing rknn model.")
if self.rknn.init_runtime(core_mask=self.core_mask) != 0:
logger.error("Error initializing rknn runtime. Do you run docker in privileged mode?")
logger.error(
"Error initializing rknn runtime. Do you run docker in privileged mode?"
)
def __del__(self):
self.rknn.release()