diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py index 18c577fb0..49b05dd05 100644 --- a/frigate/util/downloader.py +++ b/frigate/util/downloader.py @@ -51,12 +51,14 @@ class ModelDownloader: download_path: str, file_names: List[str], download_func: Callable[[str], None], + complete_func: Callable[[], None] | None = None, silent: bool = False, ): self.model_name = model_name self.download_path = download_path self.file_names = file_names self.download_func = download_func + self.complete_func = complete_func self.silent = silent self.requestor = InterProcessRequestor() self.download_thread = None @@ -97,6 +99,9 @@ class ModelDownloader: }, ) + if self.complete_func: + self.complete_func() + self.requestor.stop() self.download_complete.set() diff --git a/frigate/util/model.py b/frigate/util/model.py index 106815b96..aeffd006a 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -169,19 +169,21 @@ class FaceClassificationModel: self.face_recognizer: cv2.face.LBPHFaceRecognizer = None download_path = os.path.join(MODEL_CACHE_DIR, "facedet") - model_files = ["facedet.onnx", "landmarkdet.yaml"] - self.model_urls = [ - "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx", - "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml", - ] + self.model_files = { + "facedet.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx", + "landmarkdet.yaml": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml", + } - if not all(os.path.exists(os.path.join(download_path, n)) for n in model_files): - logger.debug(f"starting model download for {self.model_name}") + if not all( + os.path.exists(os.path.join(download_path, n)) + for n in self.model_files.keys() + ): self.downloader = ModelDownloader( - model_name="facedet.onnx", + model_name="facedet", download_path=download_path, - file_names=model_files, + file_names=self.model_files.keys(), download_func=self.__download_models, + complete_func=self.__build_detector, ) self.downloader.ensure_model_files() else: @@ -193,9 +195,9 @@ class FaceClassificationModel: def __download_models(self, path: str) -> None: try: file_name = os.path.basename(path) - ModelDownloader.download_from_url(self.model_urls[file_name], path) - except Exception: - pass + ModelDownloader.download_from_url(self.model_files[file_name], path) + except Exception as e: + logger.error(f"Failed to download {path}: {e}") def __build_detector(self) -> None: self.face_detector = cv2.FaceDetectorYN.create( @@ -209,6 +211,9 @@ class FaceClassificationModel: self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml") def __build_classifier(self) -> None: + if not self.landmark_detector: + return None + labels = [] faces = []