diff --git a/docker/main/requirements-wheels.txt b/docker/main/requirements-wheels.txt index 11ad94f3f..02dd62795 100644 --- a/docker/main/requirements-wheels.txt +++ b/docker/main/requirements-wheels.txt @@ -32,6 +32,7 @@ ws4py == 0.5.* unidecode == 1.3.* # OpenVino & ONNX openvino == 2024.3.* +onnx == 1.17.* onnxruntime-openvino == 1.19.* ; platform_machine == 'x86_64' onnxruntime == 1.19.* ; platform_machine == 'aarch64' # Embeddings diff --git a/frigate/db/sqlitevecq.py b/frigate/db/sqlitevecq.py index b852e06e5..d123edea8 100644 --- a/frigate/db/sqlitevecq.py +++ b/frigate/db/sqlitevecq.py @@ -59,6 +59,6 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase): self.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0( id TEXT PRIMARY KEY, - face_embedding FLOAT[768] distance_metric=cosine + face_embedding FLOAT[512] distance_metric=cosine ); """) diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 43658168f..b8e0b21a3 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -124,6 +124,21 @@ class Embeddings: device="GPU" if config.model_size == "large" else "CPU", ) + self.face_embedding = None + + if self.config.face_recognition.enabled: + self.face_embedding = GenericONNXEmbedding( + model_name="resnet100/arcface", + model_file="arcfaceresnet100-8.onnx", + download_urls={ + "arcfaceresnet100-8.onnx": "https://media.githubusercontent.com/media/onnx/models/bb0d4cf3d4e2a5f7376c13a08d337e86296edbe8/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx" + }, + model_size="large", + model_type=ModelTypeEnum.face, + requestor=self.requestor, + device="GPU", + ) + def embed_thumbnail( self, event_id: str, thumbnail: bytes, upsert: bool = True ) -> ndarray: @@ -219,9 +234,7 @@ class Embeddings: return embeddings def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray: - # Convert thumbnail bytes to PIL Image - image = 
Image.open(io.BytesIO(thumbnail)).convert("RGB") - embedding = self.vision_embedding([image])[0] + embedding = self.face_embedding(thumbnail)[0] if upsert: rand_id = "".join( diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index 6ea495a30..241bbbe11 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -19,7 +19,7 @@ from frigate.comms.inter_process import InterProcessRequestor from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE from frigate.types import ModelStatusTypesEnum from frigate.util.downloader import ModelDownloader -from frigate.util.model import ONNXModelRunner +from frigate.util.model import ONNXModelRunner, fix_spatial_mode warnings.filterwarnings( "ignore", @@ -47,7 +47,7 @@ class GenericONNXEmbedding: model_file: str, download_urls: Dict[str, str], model_size: str, - model_type: str, + model_type: ModelTypeEnum, requestor: InterProcessRequestor, tokenizer_file: Optional[str] = None, device: str = "AUTO", @@ -57,7 +57,7 @@ class GenericONNXEmbedding: self.tokenizer_file = tokenizer_file self.requestor = requestor self.download_urls = download_urls - self.model_type = model_type # 'text' or 'vision' + self.model_type = model_type self.model_size = model_size self.device = device self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) @@ -93,14 +93,19 @@ class GenericONNXEmbedding: def _download_model(self, path: str): try: file_name = os.path.basename(path) + download_path = None + if file_name in self.download_urls: - ModelDownloader.download_from_url(self.download_urls[file_name], path) + download_path = ModelDownloader.download_from_url( + self.download_urls[file_name], path + ) elif ( file_name == self.tokenizer_file and self.model_type == ModelTypeEnum.text ): if not os.path.exists(path + "/" + self.model_name): logger.info(f"Downloading {self.model_name} tokenizer") + tokenizer = AutoTokenizer.from_pretrained( self.model_name, 
trust_remote_code=True, @@ -109,6 +114,11 @@ class GenericONNXEmbedding: ) tokenizer.save_pretrained(path) + # the onnx model has incorrect spatial mode + # set by default, update then save model. + if download_path is not None and self.model_type == ModelTypeEnum.face: + fix_spatial_mode(download_path) + self.downloader.requestor.send_data( UPDATE_MODEL_STATE, { @@ -131,8 +141,11 @@ class GenericONNXEmbedding: self.downloader.wait_for_download() if self.model_type == ModelTypeEnum.text: self.tokenizer = self._load_tokenizer() - else: + elif self.model_type == ModelTypeEnum.vision: self.feature_extractor = self._load_feature_extractor() + elif self.model_type == ModelTypeEnum.face: + self.feature_extractor = [] + self.runner = ONNXModelRunner( os.path.join(self.download_path, self.model_file), self.device, @@ -172,16 +185,37 @@ class GenericONNXEmbedding: self.feature_extractor(images=image, return_tensors="np") for image in processed_images ] + elif self.model_type == ModelTypeEnum.face: + if isinstance(raw_inputs, list): + raise ValueError("Face embedding does not support batch inputs.") + + pil = self._process_image(raw_inputs) + og = np.array(pil).astype(np.float32) + + # Image must be 112x112 + og_h, og_w, channels = og.shape + frame = np.full((112, 112, channels), (0, 0, 0), dtype=np.float32) + + # compute center offset + x_center = (112 - og_w) // 2 + y_center = (112 - og_h) // 2 + + # copy image into center of result image + frame[y_center : y_center + og_h, x_center : x_center + og_w] = og + + frame = np.expand_dims(frame, axis=0) + frame = np.transpose(frame, (0, 3, 1, 2)) + return [{"data": frame}] else: raise ValueError(f"Unable to preprocess inputs for {self.model_type}") - def _process_image(self, image): + def _process_image(self, image, output: str = "RGB") -> Image.Image: if isinstance(image, str): if image.startswith("http"): response = requests.get(image) - image 
= Image.open(BytesIO(response.content)).convert("RGB") + image = Image.open(BytesIO(response.content)).convert(output) elif isinstance(image, bytes): - image = Image.open(BytesIO(image)).convert("RGB") + image = Image.open(BytesIO(image)).convert(output) return image diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py index 6685b0bb8..18c577fb0 100644 --- a/frigate/util/downloader.py +++ b/frigate/util/downloader.py @@ -101,7 +101,7 @@ class ModelDownloader: self.download_complete.set() @staticmethod - def download_from_url(url: str, save_path: str, silent: bool = False): + def download_from_url(url: str, save_path: str, silent: bool = False) -> Path: temporary_filename = Path(save_path).with_name( os.path.basename(save_path) + ".part" ) @@ -125,6 +125,8 @@ class ModelDownloader: if not silent: logger.info(f"Downloading complete: {url}") + return Path(save_path) + @staticmethod def mark_files_state( requestor: InterProcessRequestor, diff --git a/frigate/util/model.py b/frigate/util/model.py index 7aefe8b42..22a3ff099 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -1,8 +1,10 @@ """Model Utils""" import os +from pathlib import Path from typing import Any +import onnx import onnxruntime as ort try: @@ -63,6 +65,23 @@ def get_ort_providers( return (providers, options) +def fix_spatial_mode(path: Path) -> None: + save_path = str(path) + old_path = f"{save_path}.old" + path.rename(old_path) + + model = onnx.load(old_path) + + for node in model.graph.node: + if node.op_type == "BatchNormalization": + for attr in node.attribute: + if attr.name == "spatial": + attr.i = 1 + + onnx.save(model, save_path) + Path(old_path).unlink() + + class ONNXModelRunner: """Run onnx models optimally based on available hardware."""