Use ArcFace face embeddings instead of the generic embeddings model

Nicolas Mowen 2024-10-21 16:10:55 -06:00
parent bd95f1d270
commit 5c030fa460
6 changed files with 83 additions and 13 deletions

View File

@@ -32,6 +32,7 @@ ws4py == 0.5.*
 unidecode == 1.3.*
 # OpenVino & ONNX
 openvino == 2024.3.*
+onnx == 1.17.*
 onnxruntime-openvino == 1.19.* ; platform_machine == 'x86_64'
 onnxruntime == 1.19.* ; platform_machine == 'aarch64'
 # Embeddings

View File

@@ -59,6 +59,6 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
         self.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0(
                 id TEXT PRIMARY KEY,
-                face_embedding FLOAT[768] distance_metric=cosine
+                face_embedding FLOAT[512] distance_metric=cosine
             );
             """)

View File

@@ -124,6 +124,21 @@ class Embeddings:
             device="GPU" if config.model_size == "large" else "CPU",
         )

+        self.face_embedding = None
+
+        if self.config.face_recognition.enabled:
+            self.face_embedding = GenericONNXEmbedding(
+                model_name="resnet100/arcface",
+                model_file="arcfaceresnet100-8.onnx",
+                download_urls={
+                    "arcfaceresnet100-8.onnx": "https://media.githubusercontent.com/media/onnx/models/bb0d4cf3d4e2a5f7376c13a08d337e86296edbe8/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx"
+                },
+                model_size="large",
+                model_type=ModelTypeEnum.face,
+                requestor=self.requestor,
+                device="GPU",
+            )
+
     def embed_thumbnail(
         self, event_id: str, thumbnail: bytes, upsert: bool = True
     ) -> ndarray:
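The new embedder points at the ONNX model zoo export of ArcFace ResNet100. A standalone way to confirm the shapes the rest of this change relies on (112x112 inputs, 512-dimensional outputs) is to open the downloaded file with onnxruntime; the cache path below just mirrors model_name and model_file and is an assumption, not something this diff defines.

import onnxruntime as ort

# assumed cache location: MODEL_CACHE_DIR / model_name / model_file
session = ort.InferenceSession(
    "/config/model_cache/resnet100/arcface/arcfaceresnet100-8.onnx",
    providers=["CPUExecutionProvider"],
)

for tensor in session.get_inputs():
    print("input:", tensor.name, tensor.shape)   # expected roughly: 'data', [N, 3, 112, 112]
for tensor in session.get_outputs():
    print("output:", tensor.name, tensor.shape)  # expected roughly: [N, 512]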
@@ -219,9 +234,7 @@ class Embeddings:
         return embeddings

     def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
-        # Convert thumbnail bytes to PIL Image
-        image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
-        embedding = self.vision_embedding([image])[0]
+        embedding = self.face_embedding(thumbnail)[0]

         if upsert:
             rand_id = "".join(
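embed_face now feeds the raw thumbnail bytes straight to the ArcFace runner instead of routing a PIL image through the generic vision embedding. A hypothetical caller-side sketch, with variable names that are illustrative only: encode a face crop to JPEG bytes and let embed_face store the resulting 512-dimensional vector.

import cv2
import numpy as np


def embed_detected_face(embeddings, label: str, frame: np.ndarray, box) -> np.ndarray:
    # crop the detected face region from the frame and encode it as JPEG
    x1, y1, x2, y2 = box
    ok, jpg = cv2.imencode(".jpg", frame[y1:y2, x1:x2])
    if not ok:
        raise RuntimeError("failed to encode face crop")

    # upsert=True also writes the 512-d embedding into vec_faces
    return embeddings.embed_face(label, jpg.tobytes(), upsert=True)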

View File

@@ -19,7 +19,7 @@ from frigate.comms.inter_process import InterProcessRequestor
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
-from frigate.util.model import ONNXModelRunner
+from frigate.util.model import ONNXModelRunner, fix_spatial_mode

 warnings.filterwarnings(
     "ignore",
@@ -47,7 +47,7 @@ class GenericONNXEmbedding:
         model_file: str,
         download_urls: Dict[str, str],
         model_size: str,
-        model_type: str,
+        model_type: ModelTypeEnum,
         requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
@@ -57,7 +57,7 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.model_type = model_type  # 'text' or 'vision'
+        self.model_type = model_type
         self.model_size = model_size
         self.device = device
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
@@ -93,14 +93,19 @@ class GenericONNXEmbedding:
     def _download_model(self, path: str):
         try:
             file_name = os.path.basename(path)
+            download_path = None
+
             if file_name in self.download_urls:
-                ModelDownloader.download_from_url(self.download_urls[file_name], path)
+                download_path = ModelDownloader.download_from_url(
+                    self.download_urls[file_name], path
+                )
             elif (
                 file_name == self.tokenizer_file
                 and self.model_type == ModelTypeEnum.text
             ):
                 if not os.path.exists(path + "/" + self.model_name):
                     logger.info(f"Downloading {self.model_name} tokenizer")
+
                 tokenizer = AutoTokenizer.from_pretrained(
                     self.model_name,
                     trust_remote_code=True,
@@ -109,6 +114,12 @@ class GenericONNXEmbedding:
                 )
                 tokenizer.save_pretrained(path)

+            # the onnx model ships with an incorrect spatial mode set by
+            # default, so update the attribute and re-save the model
+            print(f"download path is {download_path} and model type is {self.model_type}")
+            if download_path is not None and self.model_type == ModelTypeEnum.face:
+                fix_spatial_mode(download_path)
+
             self.downloader.requestor.send_data(
                 UPDATE_MODEL_STATE,
                 {
@@ -131,8 +142,11 @@ class GenericONNXEmbedding:
                 self.downloader.wait_for_download()
             if self.model_type == ModelTypeEnum.text:
                 self.tokenizer = self._load_tokenizer()
-            else:
+            elif self.model_type == ModelTypeEnum.vision:
                 self.feature_extractor = self._load_feature_extractor()
+            elif self.model_type == ModelTypeEnum.face:
+                self.feature_extractor = []
+
             self.runner = ONNXModelRunner(
                 os.path.join(self.download_path, self.model_file),
                 self.device,
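The loader now branches on a ModelTypeEnum.face member that the diff references but does not show. A minimal sketch of what the enum presumably looks like after this change; only the text, vision, and face members are implied by the code above, and the exact values are an assumption.

from enum import Enum


class ModelTypeEnum(str, Enum):
    text = "text"
    vision = "vision"
    face = "face"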
@@ -172,16 +186,37 @@ class GenericONNXEmbedding:
                 self.feature_extractor(images=image, return_tensors="np")
                 for image in processed_images
             ]
+        elif self.model_type == ModelTypeEnum.face:
+            if isinstance(raw_inputs, list):
+                raise ValueError("Face embedding does not support batch inputs.")
+
+            pil = self._process_image(raw_inputs)
+            og = np.array(pil).astype(np.float32)
+
+            # the image must be 112x112
+            og_h, og_w, channels = og.shape
+            frame = np.full((112, 112, channels), (0, 0, 0), dtype=np.float32)
+
+            # compute the center offset
+            x_center = (112 - og_w) // 2
+            y_center = (112 - og_h) // 2
+
+            # copy the image into the center of the result frame
+            frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
+
+            frame = np.expand_dims(frame, axis=0)
+            frame = np.transpose(frame, (0, 3, 1, 2))
+            return [{"data": frame}]
         else:
             raise ValueError(f"Unable to preprocess inputs for {self.model_type}")

-    def _process_image(self, image):
+    def _process_image(self, image, output: str = "RGB") -> Image.Image:
         if isinstance(image, str):
             if image.startswith("http"):
                 response = requests.get(image)
-                image = Image.open(BytesIO(response.content)).convert("RGB")
+                image = Image.open(BytesIO(response.content)).convert(output)
         elif isinstance(image, bytes):
-            image = Image.open(BytesIO(image)).convert("RGB")
+            image = Image.open(BytesIO(image)).convert(output)

         return image
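The face branch center-pads the crop onto a black 112x112 canvas and converts it to a batched NCHW tensor keyed by the model's "data" input. It also appears to assume the incoming crop is already no larger than 112x112; a bigger crop would need a resize step that this diff does not add. A standalone sketch of the same transform for reference:

import numpy as np
from PIL import Image


def preprocess_face(image: Image.Image) -> np.ndarray:
    og = np.array(image.convert("RGB")).astype(np.float32)
    og_h, og_w, channels = og.shape

    # center-pad onto a black 112x112 canvas
    frame = np.zeros((112, 112, channels), dtype=np.float32)
    x_center = (112 - og_w) // 2
    y_center = (112 - og_h) // 2
    frame[y_center : y_center + og_h, x_center : x_center + og_w] = og

    # add a batch dimension and convert HWC -> NCHW
    return np.transpose(np.expand_dims(frame, axis=0), (0, 3, 1, 2))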

View File

@@ -101,7 +101,7 @@ class ModelDownloader:
         self.download_complete.set()

     @staticmethod
-    def download_from_url(url: str, save_path: str, silent: bool = False):
+    def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
         temporary_filename = Path(save_path).with_name(
             os.path.basename(save_path) + ".part"
         )
@@ -125,6 +125,8 @@ class ModelDownloader:
         if not silent:
             logger.info(f"Downloading complete: {url}")

+        return Path(save_path)
+
     @staticmethod
     def mark_files_state(
         requestor: InterProcessRequestor,
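Returning the final Path lets callers such as _download_model above chain the download directly into post-processing. A trivial illustration with placeholder URL and cache path, neither of which comes from this commit:

from pathlib import Path

from frigate.util.downloader import ModelDownloader
from frigate.util.model import fix_spatial_mode

# placeholder values for illustration only
model_path: Path = ModelDownloader.download_from_url(
    "https://example.com/arcfaceresnet100-8.onnx",
    "/config/model_cache/resnet100/arcface/arcfaceresnet100-8.onnx",
)
fix_spatial_mode(model_path)  # the same chaining _download_model now performs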

View File

@ -1,8 +1,10 @@
"""Model Utils""" """Model Utils"""
import os import os
from pathlib import Path
from typing import Any from typing import Any
import onnx
import onnxruntime as ort import onnxruntime as ort
try: try:
@@ -63,6 +65,23 @@ def get_ort_providers(
     return (providers, options)


+def fix_spatial_mode(path: Path) -> None:
+    save_path = str(path)
+    old_path = f"{save_path}.old"
+    path.rename(old_path)
+
+    model = onnx.load(old_path)
+
+    for node in model.graph.node:
+        if node.op_type == "BatchNormalization":
+            for attr in node.attribute:
+                if attr.name == "spatial":
+                    attr.i = 1
+
+    onnx.save(model, save_path)
+    Path(old_path).unlink()
+
+
 class ONNXModelRunner:
     """Run onnx models optimally based on available hardware."""