clip-italian-demo / utils.py
4rtemi5's picture
Add download progress bar...
3418bf2
raw history blame
No virus
2.19 kB
import os
import natsort
from tqdm import tqdm
import torch
from jax import numpy as jnp
from PIL import Image as PilImage
class CustomDataSet(torch.utils.data.Dataset):
    """Dataset over every entry of ``main_dir``, in natural-sort order.

    No filtering is done: every name returned by ``os.listdir`` is treated
    as an image file, so stray non-image entries will fail in ``__getitem__``.
    """

    def __init__(self, main_dir, transform):
        """Index the directory.

        Args:
            main_dir: Directory whose entries are the images.
            transform: Callable applied to each RGB ``PIL.Image`` before it is
                returned (presumably a torchvision transform; confirm with caller).
        """
        self.main_dir = main_dir
        self.transform = transform
        # natsort keeps "img2" before "img10", unlike plain lexicographic sort.
        self.total_imgs = natsort.natsorted(os.listdir(main_dir))

    def __len__(self):
        """Return the number of indexed directory entries."""
        return len(self.total_imgs)

    def get_image_name(self, idx):
        """Return the bare file name (no directory) of item ``idx``."""
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        """Load item ``idx``, convert it to RGB and apply ``self.transform``."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # Close the underlying file handle: ``convert`` returns a new, fully
        # loaded image, so the source image is no longer needed afterwards.
        with PilImage.open(img_loc) as source:
            image = source.convert("RGB")
        return self.transform(image)
class DownloadProgressBar(tqdm):
    """tqdm bar whose ``update_to`` matches ``urlretrieve``'s reporthook signature.

    ``urllib.request.urlretrieve`` calls its reporthook with
    ``(block_count, block_size, total_size)``, which is exactly the argument
    order ``update_to`` accepts.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to ``b * bsize`` bytes.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block in bytes.
            tsize: Total size in bytes; ``None`` or a negative value (``-1`` is
                what ``urlretrieve`` passes when the size is unknown) leaves
                ``total`` untouched.

        Returns:
            Whatever ``tqdm.update`` returns.
        """
        # A negative total would be nonsensical (and makes ``len(bar)`` raise),
        # so only record sizes the server actually reported.
        if tsize is not None and tsize >= 0:
            self.total = tsize
        # update() takes a delta; convert the absolute byte count into the
        # increment since the last call (self.n is the running count).
        return self.update(b * bsize - self.n)
def text_encoder(text, model, tokenizer, max_length=96):
    """Encode one string into an L2-normalised embedding with a leading batch axis.

    Args:
        text: The query string to encode.
        model: Object exposing ``get_text_features(input_ids, attention_mask)``;
            assumed to return a batch of embeddings shaped (batch, dim) — only
            the first row is used. TODO confirm against the model class.
        tokenizer: Callable accepting HuggingFace-style keyword arguments
            (``max_length``, ``truncation``, ``padding``, ``return_tensors``) and
            returning a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Token length to pad/truncate to (defaults to 96, the value
            previously hard-coded here).

    Returns:
        The unit-norm embedding with a size-1 leading axis, i.e. (1, dim)
        under the shape assumption above.
    """
    inputs = tokenizer(
        [text],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    embedding = model.get_text_features(
        inputs["input_ids"], inputs["attention_mask"]
    )[0]
    # Out-of-place division: never mutates the array the model returned.
    embedding = embedding / jnp.linalg.norm(embedding)
    return jnp.expand_dims(embedding, axis=0)
def precompute_image_features(model, loader):
    """Compute L2-normalised image embeddings for every batch in ``loader``.

    Args:
        model: Object exposing ``get_image_features(pixels)``.
        loader: Iterable of torch tensor batches; assumed channel-first
            (batch, channels, height, width) — TODO confirm. Each batch is
            permuted to channel-last before being passed to the model.

    Returns:
        A jnp array with one unit-norm embedding per image, stacked in loader
        order. An empty loader yields ``jnp.array([])``.
    """
    batches = []
    for images in tqdm(loader):
        # permute(0, 2, 3, 1) moves axis 1 to the end (NCHW -> NHWC).
        pixels = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(pixels)
        # Normalise each embedding (last axis) to unit length.
        features = features / jnp.linalg.norm(features, axis=-1, keepdims=True)
        batches.append(features)
    if not batches:
        # Keep the historical empty-loader result.
        return jnp.array([])
    # One concatenate, instead of splitting every batch into per-row arrays
    # and re-stacking them, gives the same stacked result.
    return jnp.concatenate(batches, axis=0)
def find_image(
    text_query, model, dataset, tokenizer, image_features, n=1, prefix="photos/"
):
    """Return paths of the ``n`` images most similar to ``text_query``.

    Args:
        text_query: Free-text query.
        model: Passed to ``text_encoder`` to embed the query.
        dataset: Object with ``get_image_name(idx)`` whose indices line up with
            the rows of ``image_features``.
        tokenizer: Passed to ``text_encoder``.
        image_features: Image embeddings, one row per image (presumably the
            output of ``precompute_image_features``; confirm with caller).
        n: Number of results. Values above the number of images are clamped;
            values <= 0 return an empty list.
        prefix: String prepended to each file name (defaults to ``"photos/"``).

    Returns:
        File paths ordered from most to least similar.
    """
    zeroshot_weights = text_encoder(text_query, model, tokenizer)
    # Defensive re-normalisation; text_encoder already returns a unit vector.
    zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
    # Dot products (higher = more similar), flattened to one score per image.
    similarities = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))[:, 0]
    # JAX clamps out-of-range indices instead of raising, so an oversized n
    # would silently repeat the same image; bound it explicitly.
    n = max(0, min(n, similarities.shape[0]))
    # Sort once; reversing ascending argsort gives the same order as the
    # previous [-1], [-2], ... lookups.
    ranked = jnp.argsort(similarities)[::-1][:n]
    return [prefix + dataset.get_image_name(int(idx)) for idx in ranked]