Use arcface face embeddings instead of generic embeddings model

parent bd95f1d270
commit 5c030fa460
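
In short, face recognition now runs the ONNX ArcFace ResNet100 model instead of the generic vision embedding, so face vectors shrink from 768 to 512 dimensions and the vec_faces schema changes to match. The snippet below is a minimal, hypothetical sanity check of that output shape with onnxruntime; the input tensor name "data" and the 112x112 input size come from the diff, and the (1, 512) output shape is the assumption being verified.

    # Hypothetical check: ArcFace should map a 1x3x112x112 input to one 512-d vector.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        "arcfaceresnet100-8.onnx", providers=["CPUExecutionProvider"]
    )
    dummy = np.zeros((1, 3, 112, 112), dtype=np.float32)  # padded face crop, NCHW
    (embedding,) = session.run(None, {"data": dummy})
    assert embedding.shape == (1, 512)
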
@@ -32,6 +32,7 @@ ws4py == 0.5.*
 unidecode == 1.3.*
 # OpenVino & ONNX
 openvino == 2024.3.*
+onnx == 1.17.*
 onnxruntime-openvino == 1.19.* ; platform_machine == 'x86_64'
 onnxruntime == 1.19.* ; platform_machine == 'aarch64'
 # Embeddings
@@ -59,6 +59,6 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
         self.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0(
                 id TEXT PRIMARY KEY,
-                face_embedding FLOAT[768] distance_metric=cosine
+                face_embedding FLOAT[512] distance_metric=cosine
             );
             """)
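
With the column narrowed to FLOAT[512] and cosine distance, looking up the closest known face becomes a plain vec0 KNN query. A minimal sketch, assuming the sqlite-vec extension is already loaded (as SqliteVecQueueDatabase does), `db` is that database, and `embedding` is a 512-element float32 array; the 0.6 cutoff is illustrative only, not a value from this commit.

    import numpy as np

    query_blob = np.asarray(embedding, dtype=np.float32).tobytes()

    row = db.execute_sql(
        """
        SELECT id, distance
        FROM vec_faces
        WHERE face_embedding MATCH ?
        ORDER BY distance
        LIMIT 1
        """,
        [query_blob],
    ).fetchone()

    if row is not None and row[1] < 0.6:  # illustrative cosine-distance threshold
        print(f"closest face: {row[0]} (distance {row[1]:.3f})")
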
@@ -124,6 +124,21 @@ class Embeddings:
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
+        self.face_embedding = None
+
+        if self.config.face_recognition.enabled:
+            self.face_embedding = GenericONNXEmbedding(
+                model_name="resnet100/arcface",
+                model_file="arcfaceresnet100-8.onnx",
+                download_urls={
+                    "arcfaceresnet100-8.onnx": "https://media.githubusercontent.com/media/onnx/models/bb0d4cf3d4e2a5f7376c13a08d337e86296edbe8/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx"
+                },
+                model_size="large",
+                model_type=ModelTypeEnum.face,
+                requestor=self.requestor,
+                device="GPU",
+            )
+
     def embed_thumbnail(
         self, event_id: str, thumbnail: bytes, upsert: bool = True
     ) -> ndarray:
@@ -219,9 +234,7 @@ class Embeddings:
         return embeddings
 
     def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
-        # Convert thumbnail bytes to PIL Image
-        image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
-        embedding = self.vision_embedding([image])[0]
+        embedding = self.face_embedding(thumbnail)[0]
 
         if upsert:
             rand_id = "".join(
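
Since vec_faces stores the ArcFace vectors with distance_metric=cosine, comparing two faces outside the database amounts to a cosine similarity between two 512-d vectors. A minimal numpy sketch; the 0.4 cutoff is an illustrative assumption, not part of this commit.

    import numpy as np

    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        # Both vectors are assumed to be 512-d ArcFace embeddings.
        a = a.flatten()
        b = b.flatten()
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    # e.g. same_person = cosine_similarity(emb_a, emb_b) > 0.4  # illustrative threshold
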
@@ -19,7 +19,7 @@ from frigate.comms.inter_process import InterProcessRequestor
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
-from frigate.util.model import ONNXModelRunner
+from frigate.util.model import ONNXModelRunner, fix_spatial_mode
 
 warnings.filterwarnings(
     "ignore",
@@ -47,7 +47,7 @@ class GenericONNXEmbedding:
         model_file: str,
         download_urls: Dict[str, str],
         model_size: str,
-        model_type: str,
+        model_type: ModelTypeEnum,
         requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
@@ -57,7 +57,7 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.model_type = model_type  # 'text' or 'vision'
+        self.model_type = model_type
         self.model_size = model_size
         self.device = device
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
@@ -93,14 +93,19 @@ class GenericONNXEmbedding:
     def _download_model(self, path: str):
         try:
             file_name = os.path.basename(path)
+            download_path = None
+
             if file_name in self.download_urls:
-                ModelDownloader.download_from_url(self.download_urls[file_name], path)
+                download_path = ModelDownloader.download_from_url(
+                    self.download_urls[file_name], path
+                )
             elif (
                 file_name == self.tokenizer_file
                 and self.model_type == ModelTypeEnum.text
             ):
                 if not os.path.exists(path + "/" + self.model_name):
                     logger.info(f"Downloading {self.model_name} tokenizer")
 
                     tokenizer = AutoTokenizer.from_pretrained(
                         self.model_name,
                         trust_remote_code=True,
@@ -109,6 +114,12 @@ class GenericONNXEmbedding:
                     )
                     tokenizer.save_pretrained(path)
 
+            # the onnx model has incorrect spatial mode
+            # set by default, update then save model.
+            print(f"download path is {download_path} and model type is {self.model_type}")
+            if download_path is not None and self.model_type == ModelTypeEnum.face:
+                fix_spatial_mode(download_path)
+
             self.downloader.requestor.send_data(
                 UPDATE_MODEL_STATE,
                 {
@@ -131,8 +142,11 @@ class GenericONNXEmbedding:
                 self.downloader.wait_for_download()
             if self.model_type == ModelTypeEnum.text:
                 self.tokenizer = self._load_tokenizer()
-            else:
+            elif self.model_type == ModelTypeEnum.vision:
                 self.feature_extractor = self._load_feature_extractor()
+            elif self.model_type == ModelTypeEnum.face:
+                self.feature_extractor = []
+
             self.runner = ONNXModelRunner(
                 os.path.join(self.download_path, self.model_file),
                 self.device,
@@ -172,16 +186,37 @@ class GenericONNXEmbedding:
                 self.feature_extractor(images=image, return_tensors="np")
                 for image in processed_images
             ]
+        elif self.model_type == ModelTypeEnum.face:
+            if isinstance(raw_inputs, list):
+                raise ValueError("Face embedding does not support batch inputs.")
+
+            pil = self._process_image(raw_inputs)
+            og = np.array(pil).astype(np.float32)
+
+            # Image must be 112x112
+            og_h, og_w, channels = og.shape
+            frame = np.full((112, 112, channels), (0, 0, 0), dtype=np.float32)
+
+            # compute center offset
+            x_center = (112 - og_w) // 2
+            y_center = (112 - og_h) // 2
+
+            # copy img image into center of result image
+            frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
+
+            frame = np.expand_dims(frame, axis=0)
+            frame = np.transpose(frame, (0, 3, 1, 2))
+            return [{"data": frame}]
         else:
             raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
 
-    def _process_image(self, image):
+    def _process_image(self, image, output: str = "RGB") -> Image.Image:
         if isinstance(image, str):
             if image.startswith("http"):
                 response = requests.get(image)
-                image = Image.open(BytesIO(response.content)).convert("RGB")
+                image = Image.open(BytesIO(response.content)).convert(output)
         elif isinstance(image, bytes):
-            image = Image.open(BytesIO(image)).convert("RGB")
+            image = Image.open(BytesIO(image)).convert(output)
 
         return image
 
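
The face branch above letterboxes the crop into a 112x112 canvas and reorders it to NCHW before handing it to the runner. A small usage sketch of that same padding logic on an undersized crop, just to show the expected shapes; the crop size and variable names here are illustrative, not part of the diff.

    import numpy as np
    from PIL import Image

    crop = Image.new("RGB", (80, 96))            # stand-in for a real 80x96 face crop
    og = np.array(crop).astype(np.float32)       # (96, 80, 3), HWC

    og_h, og_w, channels = og.shape
    frame = np.zeros((112, 112, channels), dtype=np.float32)
    x_center = (112 - og_w) // 2
    y_center = (112 - og_h) // 2
    frame[y_center : y_center + og_h, x_center : x_center + og_w] = og

    frame = np.transpose(np.expand_dims(frame, axis=0), (0, 3, 1, 2))
    assert frame.shape == (1, 3, 112, 112)       # NCHW input expected by the model
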
@@ -101,7 +101,7 @@ class ModelDownloader:
         self.download_complete.set()
 
     @staticmethod
-    def download_from_url(url: str, save_path: str, silent: bool = False):
+    def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
         temporary_filename = Path(save_path).with_name(
             os.path.basename(save_path) + ".part"
         )
@@ -125,6 +125,8 @@ class ModelDownloader:
         if not silent:
             logger.info(f"Downloading complete: {url}")
 
+        return Path(save_path)
+
     @staticmethod
     def mark_files_state(
         requestor: InterProcessRequestor,
@@ -1,8 +1,10 @@
 """Model Utils"""
 
 import os
+from pathlib import Path
 from typing import Any
 
+import onnx
 import onnxruntime as ort
 
 try:
@@ -63,6 +65,23 @@ def get_ort_providers(
     return (providers, options)
 
 
+def fix_spatial_mode(path: Path) -> None:
+    save_path = str(path)
+    old_path = f"{save_path}.old"
+    path.rename(old_path)
+
+    model = onnx.load(old_path)
+
+    for node in model.graph.node:
+        if node.op_type == "BatchNormalization":
+            for attr in node.attribute:
+                if attr.name == "spatial":
+                    attr.i = 1
+
+    onnx.save(model, save_path)
+    Path(old_path).unlink()
+
+
 class ONNXModelRunner:
     """Run onnx models optimally based on available hardware."""
 
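
fix_spatial_mode rewrites the spatial attribute of every BatchNormalization node to 1 and saves the model back in place. A minimal sketch for inspecting those attributes before or after the fix, using only the onnx graph API imported above; the file path is illustrative.

    import onnx

    model = onnx.load("arcfaceresnet100-8.onnx")
    for node in model.graph.node:
        if node.op_type == "BatchNormalization":
            for attr in node.attribute:
                if attr.name == "spatial":
                    print(node.name, "spatial =", attr.i)  # expect 1 after fix_spatial_mode
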