mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 09:04:28 +03:00
Handle download and loading correctly
This commit is contained in:
parent
2c182d59f0
commit
ba9a00364d
@ -51,12 +51,14 @@ class ModelDownloader:
|
||||
download_path: str,
|
||||
file_names: List[str],
|
||||
download_func: Callable[[str], None],
|
||||
complete_func: Callable[[], None] | None = None,
|
||||
silent: bool = False,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.download_path = download_path
|
||||
self.file_names = file_names
|
||||
self.download_func = download_func
|
||||
self.complete_func = complete_func
|
||||
self.silent = silent
|
||||
self.requestor = InterProcessRequestor()
|
||||
self.download_thread = None
|
||||
@ -97,6 +99,9 @@ class ModelDownloader:
|
||||
},
|
||||
)
|
||||
|
||||
if self.complete_func:
|
||||
self.complete_func()
|
||||
|
||||
self.requestor.stop()
|
||||
self.download_complete.set()
|
||||
|
||||
|
||||
@ -169,19 +169,21 @@ class FaceClassificationModel:
|
||||
self.face_recognizer: cv2.face.LBPHFaceRecognizer = None
|
||||
|
||||
download_path = os.path.join(MODEL_CACHE_DIR, "facedet")
|
||||
model_files = ["facedet.onnx", "landmarkdet.yaml"]
|
||||
self.model_urls = [
|
||||
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
|
||||
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
|
||||
]
|
||||
self.model_files = {
|
||||
"facedet.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
|
||||
"landmarkdet.yaml": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
|
||||
}
|
||||
|
||||
if not all(os.path.exists(os.path.join(download_path, n)) for n in model_files):
|
||||
logger.debug(f"starting model download for {self.model_name}")
|
||||
if not all(
|
||||
os.path.exists(os.path.join(download_path, n))
|
||||
for n in self.model_files.keys()
|
||||
):
|
||||
self.downloader = ModelDownloader(
|
||||
model_name="facedet.onnx",
|
||||
model_name="facedet",
|
||||
download_path=download_path,
|
||||
file_names=model_files,
|
||||
file_names=self.model_files.keys(),
|
||||
download_func=self.__download_models,
|
||||
complete_func=self.__build_detector,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
else:
|
||||
@ -193,9 +195,9 @@ class FaceClassificationModel:
|
||||
def __download_models(self, path: str) -> None:
|
||||
try:
|
||||
file_name = os.path.basename(path)
|
||||
ModelDownloader.download_from_url(self.model_urls[file_name], path)
|
||||
except Exception:
|
||||
pass
|
||||
ModelDownloader.download_from_url(self.model_files[file_name], path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download {path}: {e}")
|
||||
|
||||
def __build_detector(self) -> None:
|
||||
self.face_detector = cv2.FaceDetectorYN.create(
|
||||
@ -209,6 +211,9 @@ class FaceClassificationModel:
|
||||
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
|
||||
|
||||
def __build_classifier(self) -> None:
|
||||
if not self.landmark_detector:
|
||||
return None
|
||||
|
||||
labels = []
|
||||
faces = []
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user