import torch
import os
from transformers import AutoTokenizer
from jax import numpy as jnp
import json
import requests
import zipfile
import io
import natsort
from PIL import Image as PilImage
from tqdm import tqdm


class CustomDataSet(torch.utils.data.Dataset):
    """Serves the images in `main_dir` in a stable, naturally sorted order."""

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = PilImage.open(img_loc).convert("RGB")
        return self.transform(image)


def text_encoder(text, model, tokenizer):
    """Embed a single text query and L2-normalize it; returns shape (1, dim)."""
    inputs = tokenizer(
        [text],
        max_length=96,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
    embedding /= jnp.linalg.norm(embedding)
    return jnp.expand_dims(embedding, axis=0)


def precompute_image_features(model, loader):
    """Embed every image served by the loader; rows are unit-norm feature vectors."""
    image_features = []
    for images in tqdm(loader):
        # The loader yields NCHW torch tensors; the model expects NHWC numpy arrays.
        images = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(images)
        features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
        image_features.extend(features)
    return jnp.array(image_features)


def find_image(text_query, model, dataset, tokenizer, image_features, n=1):
    """Return the file paths of the `n` images most similar to `text_query`."""
    zeroshot_weights = text_encoder(text_query, model, tokenizer)
    zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)  # already unit-norm; a harmless no-op
    # Both operands are unit-norm, so these dot products are cosine similarities
    # (higher means closer), despite the variable name.
    distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
    file_paths = []
    for i in range(1, n + 1):
        idx = jnp.argsort(distances, axis=0)[-i, 0]
        file_paths.append("photos/" + dataset.get_image_name(idx))
    return file_paths
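

# --- Usage sketch -------------------------------------------------------------
# A minimal, hedged example of wiring the pieces above together. The checkpoint
# name ("clip-italian/clip-italian"), the image folder ("photos", which
# find_image hardcodes as a prefix), the 224x224 preprocessing, and the model
# class are all assumptions: any Flax model exposing get_text_features /
# get_image_features with the signatures used above should work. FlaxHybridCLIP
# lives in the Hugging Face flax examples (modeling_hybrid_clip.py); having it
# importable locally is an assumption of this sketch.
if __name__ == "__main__":
    from torchvision import transforms
    from modeling_hybrid_clip import FlaxHybridCLIP  # assumed available locally

    transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),  # yields NCHW float tensors in [0, 1]
        ]
    )
    dataset = CustomDataSet("photos", transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

    tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian")
    model = FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")

    image_features = precompute_image_features(model, loader)
    # Italian query ("a photo of a dog"), assuming the Italian checkpoint above.
    for path in find_image("una foto di un cane", model, dataset, tokenizer, image_features, n=3):
        print(path)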