Handle download and loading correctly

This commit is contained in:
Nicolas Mowen 2025-01-03 09:33:58 -07:00
parent 2c182d59f0
commit ba9a00364d
2 changed files with 22 additions and 12 deletions

View File

@ -51,12 +51,14 @@ class ModelDownloader:
download_path: str, download_path: str,
file_names: List[str], file_names: List[str],
download_func: Callable[[str], None], download_func: Callable[[str], None],
complete_func: Callable[[], None] | None = None,
silent: bool = False, silent: bool = False,
): ):
self.model_name = model_name self.model_name = model_name
self.download_path = download_path self.download_path = download_path
self.file_names = file_names self.file_names = file_names
self.download_func = download_func self.download_func = download_func
self.complete_func = complete_func
self.silent = silent self.silent = silent
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
self.download_thread = None self.download_thread = None
@ -97,6 +99,9 @@ class ModelDownloader:
}, },
) )
if self.complete_func:
self.complete_func()
self.requestor.stop() self.requestor.stop()
self.download_complete.set() self.download_complete.set()

View File

@ -169,19 +169,21 @@ class FaceClassificationModel:
self.face_recognizer: cv2.face.LBPHFaceRecognizer = None self.face_recognizer: cv2.face.LBPHFaceRecognizer = None
download_path = os.path.join(MODEL_CACHE_DIR, "facedet") download_path = os.path.join(MODEL_CACHE_DIR, "facedet")
model_files = ["facedet.onnx", "landmarkdet.yaml"] self.model_files = {
self.model_urls = [ "facedet.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx", "landmarkdet.yaml": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml", }
]
if not all(os.path.exists(os.path.join(download_path, n)) for n in model_files): if not all(
logger.debug(f"starting model download for {self.model_name}") os.path.exists(os.path.join(download_path, n))
for n in self.model_files.keys()
):
self.downloader = ModelDownloader( self.downloader = ModelDownloader(
model_name="facedet.onnx", model_name="facedet",
download_path=download_path, download_path=download_path,
file_names=model_files, file_names=self.model_files.keys(),
download_func=self.__download_models, download_func=self.__download_models,
complete_func=self.__build_detector,
) )
self.downloader.ensure_model_files() self.downloader.ensure_model_files()
else: else:
@ -193,9 +195,9 @@ class FaceClassificationModel:
def __download_models(self, path: str) -> None: def __download_models(self, path: str) -> None:
try: try:
file_name = os.path.basename(path) file_name = os.path.basename(path)
ModelDownloader.download_from_url(self.model_urls[file_name], path) ModelDownloader.download_from_url(self.model_files[file_name], path)
except Exception: except Exception as e:
pass logger.error(f"Failed to download {path}: {e}")
def __build_detector(self) -> None: def __build_detector(self) -> None:
self.face_detector = cv2.FaceDetectorYN.create( self.face_detector = cv2.FaceDetectorYN.create(
@ -209,6 +211,9 @@ class FaceClassificationModel:
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml") self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
def __build_classifier(self) -> None: def __build_classifier(self) -> None:
if not self.landmark_detector:
return None
labels = [] labels = []
faces = [] faces = []