mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
remove chroma in clip model
This commit is contained in:
parent
4444d82089
commit
3c334175c7
@ -4,18 +4,13 @@ import errno
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import requests
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import (
|
||||
Documents,
|
||||
Images,
|
||||
is_document,
|
||||
is_image,
|
||||
)
|
||||
from onnx_clip import OnnxClip
|
||||
from PIL import Image
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
|
||||
@ -27,7 +22,7 @@ class Clip(OnnxClip):
|
||||
def _load_models(
|
||||
model: str,
|
||||
silent: bool,
|
||||
) -> Tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""
|
||||
These models are a part of the container. Treat as as such.
|
||||
"""
|
||||
@ -87,20 +82,22 @@ class Clip(OnnxClip):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
|
||||
|
||||
class ClipEmbedding(EmbeddingFunction):
|
||||
"""Embedding function for CLIP model used in Chroma."""
|
||||
class ClipEmbedding:
|
||||
"""Embedding function for CLIP model."""
|
||||
|
||||
def __init__(self, model: str = "ViT-B/32"):
|
||||
"""Initialize CLIP Embedding function."""
|
||||
self.model = Clip(model)
|
||||
|
||||
def __call__(self, input: Union[Documents, Images]) -> Embeddings:
|
||||
embeddings: Embeddings = []
|
||||
def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
|
||||
embeddings = []
|
||||
for item in input:
|
||||
if is_image(item):
|
||||
if isinstance(item, Image.Image):
|
||||
result = self.model.get_image_embeddings([item])
|
||||
embeddings.append(result[0, :].tolist())
|
||||
elif is_document(item):
|
||||
embeddings.append(result[0])
|
||||
elif isinstance(item, str):
|
||||
result = self.model.get_text_embeddings([item])
|
||||
embeddings.append(result[0, :].tolist())
|
||||
embeddings.append(result[0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(item)}")
|
||||
return embeddings
|
||||
|
||||
Loading…
Reference in New Issue
Block a user