mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 09:04:28 +03:00
Handle download and loading correctly
This commit is contained in:
parent
2c182d59f0
commit
ba9a00364d
@ -51,12 +51,14 @@ class ModelDownloader:
|
|||||||
download_path: str,
|
download_path: str,
|
||||||
file_names: List[str],
|
file_names: List[str],
|
||||||
download_func: Callable[[str], None],
|
download_func: Callable[[str], None],
|
||||||
|
complete_func: Callable[[], None] | None = None,
|
||||||
silent: bool = False,
|
silent: bool = False,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.download_path = download_path
|
self.download_path = download_path
|
||||||
self.file_names = file_names
|
self.file_names = file_names
|
||||||
self.download_func = download_func
|
self.download_func = download_func
|
||||||
|
self.complete_func = complete_func
|
||||||
self.silent = silent
|
self.silent = silent
|
||||||
self.requestor = InterProcessRequestor()
|
self.requestor = InterProcessRequestor()
|
||||||
self.download_thread = None
|
self.download_thread = None
|
||||||
@ -97,6 +99,9 @@ class ModelDownloader:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.complete_func:
|
||||||
|
self.complete_func()
|
||||||
|
|
||||||
self.requestor.stop()
|
self.requestor.stop()
|
||||||
self.download_complete.set()
|
self.download_complete.set()
|
||||||
|
|
||||||
|
|||||||
@ -169,19 +169,21 @@ class FaceClassificationModel:
|
|||||||
self.face_recognizer: cv2.face.LBPHFaceRecognizer = None
|
self.face_recognizer: cv2.face.LBPHFaceRecognizer = None
|
||||||
|
|
||||||
download_path = os.path.join(MODEL_CACHE_DIR, "facedet")
|
download_path = os.path.join(MODEL_CACHE_DIR, "facedet")
|
||||||
model_files = ["facedet.onnx", "landmarkdet.yaml"]
|
self.model_files = {
|
||||||
self.model_urls = [
|
"facedet.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
|
||||||
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facedet.onnx",
|
"landmarkdet.yaml": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
|
||||||
"https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/landmarkdet.yaml",
|
}
|
||||||
]
|
|
||||||
|
|
||||||
if not all(os.path.exists(os.path.join(download_path, n)) for n in model_files):
|
if not all(
|
||||||
logger.debug(f"starting model download for {self.model_name}")
|
os.path.exists(os.path.join(download_path, n))
|
||||||
|
for n in self.model_files.keys()
|
||||||
|
):
|
||||||
self.downloader = ModelDownloader(
|
self.downloader = ModelDownloader(
|
||||||
model_name="facedet.onnx",
|
model_name="facedet",
|
||||||
download_path=download_path,
|
download_path=download_path,
|
||||||
file_names=model_files,
|
file_names=self.model_files.keys(),
|
||||||
download_func=self.__download_models,
|
download_func=self.__download_models,
|
||||||
|
complete_func=self.__build_detector,
|
||||||
)
|
)
|
||||||
self.downloader.ensure_model_files()
|
self.downloader.ensure_model_files()
|
||||||
else:
|
else:
|
||||||
@ -193,9 +195,9 @@ class FaceClassificationModel:
|
|||||||
def __download_models(self, path: str) -> None:
|
def __download_models(self, path: str) -> None:
|
||||||
try:
|
try:
|
||||||
file_name = os.path.basename(path)
|
file_name = os.path.basename(path)
|
||||||
ModelDownloader.download_from_url(self.model_urls[file_name], path)
|
ModelDownloader.download_from_url(self.model_files[file_name], path)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.error(f"Failed to download {path}: {e}")
|
||||||
|
|
||||||
def __build_detector(self) -> None:
|
def __build_detector(self) -> None:
|
||||||
self.face_detector = cv2.FaceDetectorYN.create(
|
self.face_detector = cv2.FaceDetectorYN.create(
|
||||||
@ -209,6 +211,9 @@ class FaceClassificationModel:
|
|||||||
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
|
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
|
||||||
|
|
||||||
def __build_classifier(self) -> None:
|
def __build_classifier(self) -> None:
|
||||||
|
if not self.landmark_detector:
|
||||||
|
return None
|
||||||
|
|
||||||
labels = []
|
labels = []
|
||||||
faces = []
|
faces = []
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user