from datasets import load_dataset
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from loadimg import load_img

# Apple MPS is deliberately not probed: this code is meant to run on a hosted
# Space, so only CUDA or CPU is considered.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CLIP ViT-L/14 processor and model, loaded once at import time and shared by
# every Instance below.
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained(
    "openai/clip-vit-large-patch14", device_map=device
)


class Instance:
    """Nearest-neighbour image search over a Hugging Face dataset.

    The loaded dataset must already contain an ``"embeddings"`` column; a FAISS
    index is built over it so query images can be matched against it.
    """

    def __init__(self, dataset, token=None, split="train"):
        """Load ``dataset`` and build a FAISS index over its ``"embeddings"`` column.

        Args:
            dataset: name or path of the dataset, passed to ``load_dataset``.
            token: optional Hugging Face access token, forwarded to
                ``load_dataset`` so private/gated datasets can be loaded.
            split: which split of the dataset to load.
        """
        self.dataset = dataset
        self.token = token
        self.split = split
        # The token was previously stored but never used; forward it so
        # authenticated datasets actually load.
        self.data = load_dataset(self.dataset, split=self.split, token=self.token)
        # add_faiss_index attaches the index and returns the dataset.
        self.data = self.data.add_faiss_index("embeddings")

    @staticmethod
    def embed(batch):
        """Embed a batch of images with CLIP and store the features in ``batch``.

        Intended for adding embeddings to an external dataset of existing
        images, e.g. via ``Dataset.map(Instance.embed, batched=True)``.
        Not used by this class itself.

        Args:
            batch: a batch mapping with an ``"image"`` entry holding the images.

        Returns:
            The same ``batch`` with an added ``"embeddings"`` entry: one CLIP
            image-feature vector per image, as a CPU NumPy array.
        """
        pixel_values = processor(images=batch["image"], return_tensors="pt")['pixel_values']
        pixel_values = pixel_values.to(device)
        # Inference only: skip autograd bookkeeping, and hand back a plain CPU
        # array rather than a grad-tracking device tensor.
        with torch.no_grad():
            img_emb = model.get_image_features(pixel_values)
        batch["embeddings"] = img_emb.cpu().numpy()
        return batch

    def search(self, query, k: int = 3):
        """Embed a query image and return its ``k`` nearest dataset examples.

        Args:
            query: the image to search for, in a form ``processor`` accepts
                (presumably a PIL image or array, as produced by ``load_img``).
            k: the number of results to return.

        Returns:
            scores: FAISS scores of the retrieved examples. The index built in
                ``__init__`` uses the default metric, so these are presumably
                (squared) L2 distances -- lower means more similar. They are
                not cosine similarities; confirm before relying on the scale.
            retrieved_examples: the retrieved examples, as returned by
                ``Dataset.get_nearest_examples``.
        """
        pixel_values = processor(images=query, return_tensors="pt")['pixel_values']
        pixel_values = pixel_values.to(device)
        with torch.no_grad():
            img_emb = model.get_image_features(pixel_values)[0]
        img_emb = img_emb.cpu().numpy()
        scores, retrieved_examples = self.data.get_nearest_examples(
            "embeddings", img_emb, k=k
        )
        return scores, retrieved_examples

    def high_level_search(self, img, k: int = 3):
        """Load ``img`` from any supported source and search for similar images.

        Args:
            img: input image (path, url, pillow or numpy), loaded via ``load_img``.
            k: the number of results to return.

        Returns:
            scores: the scores of the retrieved examples (see ``search``).
            retrieved_examples: the retrieved examples.
        """
        image = load_img(img)
        # Previously the results were computed and then discarded (implicit
        # ``None`` return); return them as the docstring promises.
        scores, retrieved_examples = self.search(image, k=k)
        return scores, retrieved_examples