clip-italian-demo / utils.py
g8a9's picture
[image2text] add initial version
a01e989
raw
history blame
2.46 kB
import os
import natsort
from tqdm import tqdm
import torch
from jax import numpy as jnp
from PIL import Image as PilImage
class CustomDataSet(torch.utils.data.Dataset):
    """Map-style dataset over the entries of one directory, in natural sort order.

    Each item is the file at that position, opened with PIL, converted to RGB
    and passed through ``transform``.
    """

    def __init__(self, main_dir, transform):
        """
        Args:
            main_dir: Directory whose entries form the dataset.
            transform: Callable applied to each RGB ``PIL.Image``.
        """
        self.main_dir = main_dir
        self.transform = transform
        # os.listdir returns every entry (including subdirectories and
        # non-image files); main_dir presumably holds only images — confirm.
        # natsort gives "natural" ordering (e.g. "img2" before "img10").
        self.total_imgs = natsort.natsorted(os.listdir(main_dir))

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        """Return the file name (relative to ``main_dir``) at position ``idx``."""
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``transform`` applied to it."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # PIL.Image.open is lazy and keeps the file handle open; the context
        # manager closes it once convert() has produced an independent copy.
        with PilImage.open(img_loc) as image:
            rgb_image = image.convert("RGB")
        return self.transform(rgb_image)
def text_encoder(text, model, tokenizer, *, max_length=96):
    """Embed a single text query and L2-normalise it.

    Args:
        text: The query string.
        model: Object exposing ``get_text_features(input_ids, attention_mask)``.
        tokenizer: Callable tokenizer, invoked with a one-element list and
            ``return_tensors="np"``.
        max_length: Length the tokenizer pads/truncates to (default 96, the
            value previously hard-coded here).

    Returns:
        The unit-norm embedding with a leading batch axis of size 1
        (``(1, dim)`` assuming ``get_text_features`` returns ``(batch, dim)``).
    """
    inputs = tokenizer(
        [text],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    # The batch holds a single query, so keep only its row.
    embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
    embedding /= jnp.linalg.norm(embedding)
    return jnp.expand_dims(embedding, axis=0)
def image_encoder(image, model):
    """Embed one image tensor and scale the result to unit L2 norm.

    The tensor's first axis is moved to the end (presumably CHW -> HWC —
    confirm against the transform in use), and a batch axis of size 1 is
    prepended before calling ``model.get_image_features``. The features are
    normalised along their last axis.
    """
    channels_last = image.permute(1, 2, 0).numpy()
    single_batch = jnp.expand_dims(channels_last, axis=0)
    feats = model.get_image_features(single_batch)
    feats /= jnp.linalg.norm(feats, axis=-1, keepdims=True)
    return feats
def precompute_image_features(model, loader):
    """Compute unit-norm image features for every batch produced by ``loader``.

    Args:
        model: Object exposing ``get_image_features(images)``.
        loader: Iterable of 4-D torch tensors; axis 1 is moved to the end
            before encoding (presumably NCHW -> NHWC — confirm).

    Returns:
        All per-image feature rows concatenated along axis 0, each row
        L2-normalised. Returns ``jnp.array([])`` when ``loader`` is empty.
    """
    batch_features = []
    for images in tqdm(loader):
        images = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(images)
        features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
        # Keep whole batches; concatenating once at the end avoids splitting
        # every batch into per-row device arrays and re-stacking them.
        batch_features.append(features)
    if not batch_features:
        return jnp.array([])
    return jnp.concatenate(batch_features, axis=0)
def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
    """Return the ``n`` images most similar to ``text_query``, best first.

    Args:
        text_query: Query string encoded with ``text_encoder``.
        model: Model passed through to ``text_encoder``.
        dataset: For "Unsplash", an object with ``get_image_name(idx)``;
            for "CC", an indexable whose items are appended as-is.
        tokenizer: Tokenizer passed through to ``text_encoder``.
        image_features: Precomputed image features, one row per image.
        n: Number of results; clamped to ``[0, number of images]``.
        dataset_name: Either "Unsplash" or "CC".

    Returns:
        A list of at most ``n`` entries, ordered by decreasing dot-product
        similarity.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # Validate before doing any model work.
    if dataset_name not in ("Unsplash", "CC"):
        raise ValueError(f"{dataset_name} not supported here")

    zeroshot_weights = text_encoder(text_query, model, tokenizer)
    zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
    similarities = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))

    # Sort once (ascending) and read from the end. JAX clamps out-of-bounds
    # indices instead of raising, so n is clamped explicitly; otherwise it
    # would repeat the least-similar image.
    ranking = jnp.argsort(similarities, axis=0)[:, 0]
    n = max(0, min(n, ranking.shape[0]))
    top_indices = ranking[::-1][:n].tolist()

    file_paths = []
    for idx in top_indices:
        if dataset_name == "Unsplash":
            file_paths.append("photos/" + dataset.get_image_name(idx))
        else:
            file_paths.append(dataset[idx])
    return file_paths