|
import os |
|
import natsort |
|
from tqdm import tqdm |
|
import torch |
|
from jax import numpy as jnp |
|
from PIL import Image as PilImage |
|
|
|
|
|
class CustomDataSet(torch.utils.data.Dataset): |
|
def __init__(self, main_dir, transform): |
|
self.main_dir = main_dir |
|
self.transform = transform |
|
all_imgs = os.listdir(main_dir) |
|
self.total_imgs = natsort.natsorted(all_imgs) |
|
|
|
def __len__(self): |
|
return len(self.total_imgs) |
|
|
|
def get_image_name(self, idx): |
|
return self.total_imgs[idx] |
|
|
|
def __getitem__(self, idx): |
|
img_loc = os.path.join(self.main_dir, self.total_imgs[idx]) |
|
image = PilImage.open(img_loc).convert("RGB") |
|
tensor_image = self.transform(image) |
|
return tensor_image |
|
|
|
|
|
def text_encoder(text, model, tokenizer): |
|
inputs = tokenizer( |
|
[text], |
|
max_length=96, |
|
truncation=True, |
|
padding="max_length", |
|
return_tensors="np", |
|
) |
|
embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[ |
|
0 |
|
] |
|
embedding /= jnp.linalg.norm(embedding) |
|
return jnp.expand_dims(embedding, axis=0) |
|
|
|
|
|
def image_encoder(image, model): |
|
image = image.permute(1, 2, 0).numpy() |
|
image = jnp.expand_dims(image, axis=0) |
|
features = model.get_image_features(image,) |
|
features /= jnp.linalg.norm(features, axis=-1, keepdims=True) |
|
return features |
|
|
|
|
|
def precompute_image_features(model, loader): |
|
image_features = [] |
|
for i, (images) in enumerate(tqdm(loader)): |
|
images = images.permute(0, 2, 3, 1).numpy() |
|
features = model.get_image_features(images,) |
|
features /= jnp.linalg.norm(features, axis=-1, keepdims=True) |
|
image_features.extend(features) |
|
return jnp.array(image_features) |
|
|
|
|
|
def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name): |
|
zeroshot_weights = text_encoder(text_query, model, tokenizer) |
|
zeroshot_weights /= jnp.linalg.norm(zeroshot_weights) |
|
distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1)) |
|
file_paths = [] |
|
for i in range(1, n + 1): |
|
idx = jnp.argsort(distances, axis=0)[-i, 0] |
|
|
|
if dataset_name == "Unsplash": |
|
file_paths.append("photos/" + dataset.get_image_name(idx)) |
|
elif dataset_name == "CC": |
|
file_paths.append(dataset[idx]) |
|
else: |
|
raise ValueError(f"{dataset_name} not supported here") |
|
return file_paths |
|
|