applied ruff suggested fixes

This commit is contained in:
Indrek Mandre 2024-02-09 17:06:06 +02:00
parent 128ed4d2bb
commit 8e205597af
3 changed files with 14 additions and 20 deletions

View File

@ -1,16 +1,11 @@
import glob
import logging import logging
import sys
import os
import numpy as np import numpy as np
import ctypes
from pydantic import Field
from typing_extensions import Literal from typing_extensions import Literal
import glob
from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig from frigate.detectors.detector_config import BaseDetectorConfig
from frigate.detectors.util import preprocess, yolov8_postprocess from frigate.detectors.util import preprocess, yolov8_postprocess
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +24,7 @@ class ONNXDetector(DetectionApi):
try: try:
import onnxruntime import onnxruntime
logger.info(f"ONNX: loaded onnxruntime module") logger.info("ONNX: loaded onnxruntime module")
except ModuleNotFoundError: except ModuleNotFoundError:
logger.error( logger.error(
"ONNX: module loading failed, need 'pip install onnxruntime'?!?" "ONNX: module loading failed, need 'pip install onnxruntime'?!?"

View File

@ -1,17 +1,16 @@
import logging
import sys
import os
import numpy as np
import ctypes import ctypes
import glob
import logging
import os
import subprocess
import sys
import numpy as np
from pydantic import Field from pydantic import Field
from typing_extensions import Literal from typing_extensions import Literal
import glob
import subprocess
from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig from frigate.detectors.detector_config import BaseDetectorConfig
from frigate.detectors.util import preprocess, yolov8_postprocess from frigate.detectors.util import preprocess, yolov8_postprocess
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,13 +70,13 @@ class ROCmDetector(DetectionApi):
sys.path.append("/opt/rocm/lib") sys.path.append("/opt/rocm/lib")
import migraphx import migraphx
logger.info(f"AMD/ROCm: loaded migraphx module") logger.info("AMD/ROCm: loaded migraphx module")
except ModuleNotFoundError: except ModuleNotFoundError:
logger.error("AMD/ROCm: module loading failed, missing ROCm environment?") logger.error("AMD/ROCm: module loading failed, missing ROCm environment?")
raise raise
if detector_config.conserve_cpu: if detector_config.conserve_cpu:
logger.info(f"AMD/ROCm: switching HIP to blocking mode to conserve CPU") logger.info("AMD/ROCm: switching HIP to blocking mode to conserve CPU")
ctypes.CDLL("/opt/rocm/lib/libamdhip64.so").hipSetDeviceFlags(4) ctypes.CDLL("/opt/rocm/lib/libamdhip64.so").hipSetDeviceFlags(4)
assert ( assert (
detector_config.model.model_type == "yolov8" detector_config.model.model_type == "yolov8"
@ -118,14 +117,14 @@ class ROCmDetector(DetectionApi):
self.model = migraphx.parse_tf(path) self.model = migraphx.parse_tf(path)
else: else:
raise Exception(f"AMD/ROCm: unkown model format {path}") raise Exception(f"AMD/ROCm: unkown model format {path}")
logger.info(f"AMD/ROCm: compiling the model") logger.info("AMD/ROCm: compiling the model")
self.model.compile( self.model.compile(
migraphx.get_target("gpu"), offload_copy=True, fast_math=True migraphx.get_target("gpu"), offload_copy=True, fast_math=True
) )
logger.info(f"AMD/ROCm: saving parsed model into {mxr_path}") logger.info(f"AMD/ROCm: saving parsed model into {mxr_path}")
os.makedirs("/config/model_cache/rocm", exist_ok=True) os.makedirs("/config/model_cache/rocm", exist_ok=True)
migraphx.save(self.model, mxr_path) migraphx.save(self.model, mxr_path)
logger.info(f"AMD/ROCm: model loaded") logger.info("AMD/ROCm: model loaded")
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
model_input_name = self.model.get_parameter_names()[0] model_input_name = self.model.get_parameter_names()[0]

View File

@ -1,7 +1,7 @@
import logging import logging
import numpy as np
import cv2 import cv2
import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)