remove chroma in clip model

This commit is contained in:
Josh Hawkins 2024-10-04 13:21:29 -05:00
parent 4444d82089
commit 3c334175c7

View File

@ -4,18 +4,13 @@ import errno
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import List, Union
import numpy as np
import onnxruntime as ort import onnxruntime as ort
import requests import requests
from chromadb import EmbeddingFunction, Embeddings
from chromadb.api.types import (
Documents,
Images,
is_document,
is_image,
)
from onnx_clip import OnnxClip from onnx_clip import OnnxClip
from PIL import Image
from frigate.const import MODEL_CACHE_DIR from frigate.const import MODEL_CACHE_DIR
@ -27,7 +22,7 @@ class Clip(OnnxClip):
def _load_models( def _load_models(
model: str, model: str,
silent: bool, silent: bool,
) -> Tuple[ort.InferenceSession, ort.InferenceSession]: ) -> tuple[ort.InferenceSession, ort.InferenceSession]:
""" """
These models are a part of the container. Treat as as such. These models are a part of the container. Treat as as such.
""" """
@ -87,20 +82,22 @@ class Clip(OnnxClip):
return ort.InferenceSession(path, providers=providers) return ort.InferenceSession(path, providers=providers)
class ClipEmbedding(EmbeddingFunction): class ClipEmbedding:
"""Embedding function for CLIP model used in Chroma.""" """Embedding function for CLIP model."""
def __init__(self, model: str = "ViT-B/32"): def __init__(self, model: str = "ViT-B/32"):
"""Initialize CLIP Embedding function.""" """Initialize CLIP Embedding function."""
self.model = Clip(model) self.model = Clip(model)
def __call__(self, input: Union[Documents, Images]) -> Embeddings: def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
embeddings: Embeddings = [] embeddings = []
for item in input: for item in input:
if is_image(item): if isinstance(item, Image.Image):
result = self.model.get_image_embeddings([item]) result = self.model.get_image_embeddings([item])
embeddings.append(result[0, :].tolist()) embeddings.append(result[0])
elif is_document(item): elif isinstance(item, str):
result = self.model.get_text_embeddings([item]) result = self.model.get_text_embeddings([item])
embeddings.append(result[0, :].tolist()) embeddings.append(result[0])
else:
raise ValueError(f"Unsupported input type: {type(item)}")
return embeddings return embeddings