Handle download and loading correctly

This commit is contained in:
Nicolas Mowen 2025-01-03 09:33:58 -07:00
parent 2c182d59f0
commit ba9a00364d
2 changed files with 22 additions and 12 deletions

View File

@ -51,12 +51,14 @@ class ModelDownloader:
download_path: str,
file_names: List[str],
download_func: Callable[[str], None],
complete_func: Callable[[], None] | None = None,
silent: bool = False,
):
self.model_name = model_name
self.download_path = download_path
self.file_names = file_names
self.download_func = download_func
self.complete_func = complete_func
self.silent = silent
self.requestor = InterProcessRequestor()
self.download_thread = None
@ -97,6 +99,9 @@ class ModelDownloader:
},
)
if self.complete_func:
self.complete_func()
self.requestor.stop()
self.download_complete.set()

View File

@ -169,19 +169,21 @@ class FaceClassificationModel:
self.face_recognizer: cv2.face.LBPHFaceRecognizer = None
download_path = os.path.join(MODEL_CACHE_DIR, "facedet")
model_files = ["facedet.onnx", "landmarkdet.yaml"]
self.model_urls = [
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
]
self.model_files = {
"facedet.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
"landmarkdet.yaml": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
}
if not all(os.path.exists(os.path.join(download_path, n)) for n in model_files):
logger.debug(f"starting model download for {self.model_name}")
if not all(
os.path.exists(os.path.join(download_path, n))
for n in self.model_files.keys()
):
self.downloader = ModelDownloader(
model_name="facedet.onnx",
model_name="facedet",
download_path=download_path,
file_names=model_files,
file_names=self.model_files.keys(),
download_func=self.__download_models,
complete_func=self.__build_detector,
)
self.downloader.ensure_model_files()
else:
@ -193,9 +195,9 @@ class FaceClassificationModel:
def __download_models(self, path: str) -> None:
try:
file_name = os.path.basename(path)
ModelDownloader.download_from_url(self.model_urls[file_name], path)
except Exception:
pass
ModelDownloader.download_from_url(self.model_files[file_name], path)
except Exception as e:
logger.error(f"Failed to download {path}: {e}")
def __build_detector(self) -> None:
self.face_detector = cv2.FaceDetectorYN.create(
@ -209,6 +211,9 @@ class FaceClassificationModel:
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
def __build_classifier(self) -> None:
if not self.landmark_detector:
return None
labels = []
faces = []