|
import gradio as gr |
|
from datasets import load_dataset |
|
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel |
|
import torch |
|
from tqdm.auto import tqdm |
|
import numpy as np |
|
import time |
|
|
|
# Inference device and the pretrained CLIP checkpoint used throughout the app.
device = 'cpu'
model_id = 'openai/clip-vit-base-patch32'

# CLIP model, loaded once at import time and moved onto the target device.
model = CLIPModel.from_pretrained(model_id)
model = model.to(device)

# Fast tokenizer for text queries; processor for image preprocessing.
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
|
|
|
|
|
|
|
def load_data():
    """Load the ``full_size`` config of ``frgfm/imagenette`` (train split).

    The dataset is stored in the module-level global ``imagenette``, which
    ``embedding_img`` reads after calling this function, and is also
    returned.

    Returns:
        The loaded dataset object from ``datasets.load_dataset``.
    """
    global imagenette

    # ``ignore_verifications`` was deprecated in `datasets` 2.9.1 and removed
    # in 3.0, where passing it breaks the call. Its value here (False) was
    # that parameter's default, so it is simply omitted.
    imagenette = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
    )
    return imagenette
|
|
|
def embedding_input(text_input):
    """Encode a text query into a CLIP text embedding.

    Args:
        text_input: Query string (or list of strings) to embed.

    Returns:
        The tensor produced by ``model.get_text_features``, with one row per
        input text. It is computed without gradient tracking.
    """
    # padding/truncation let the tokenizer batch a list of queries and clip
    # prompts longer than the model's context window, instead of the model
    # failing on an over-long token sequence.
    token_input = tokenizer(
        text_input,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        text_emb = model.get_text_features(**token_input.to(device))
    return text_emb
|
|
|
def embedding_img(n_samples=100, batch_size=5):
    """Sample random dataset images and embed them with CLIP.

    Calls ``load_data()``, draws ``n_samples`` random indices (with
    replacement), and encodes the selected images in batches of
    ``batch_size`` via ``model.get_image_features``.

    Side effects:
        Sets the module globals ``images`` (the sampled images, in sample
        order) and ``image_arr`` (their embeddings, one row per image).

    Args:
        n_samples: Number of images to sample. Defaults to 100.
        batch_size: Images per forward pass. Defaults to 5.

    Returns:
        A numpy array of the raw (unnormalized) image embeddings, one row
        per sampled image, or ``None`` when ``n_samples`` is 0.
    """
    global images, image_arr

    load_data()

    # randint's upper bound is exclusive, so len(imagenette) is the correct
    # bound; ``len(imagenette) + 1`` could yield an out-of-range index.
    sample_idx = np.random.randint(0, len(imagenette), n_samples).tolist()
    images = [imagenette[i]['image'] for i in sample_idx]

    batch_embs = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size)):
            batch = images[start:start + batch_size]

            pixel_values = processor(
                text=None,
                images=batch,
                return_tensors='pt',
                padding=True,
            )['pixel_values'].to(device)

            # No squeeze here: squeezing dim 0 would collapse a trailing
            # batch of a single image to 1-D and break the concatenation.
            batch_emb = model.get_image_features(pixel_values=pixel_values)
            batch_embs.append(batch_emb.cpu().numpy())

    # Concatenate once rather than re-copying the growing array every batch.
    image_arr = np.concatenate(batch_embs, axis=0) if batch_embs else None
    return image_arr
|
|
|
def norm_val(text_input):
    """Return the sampled image that best matches a text query.

    The query is embedded with ``embedding_input`` and scored by dot product
    against the L2-normalised rows of the global ``image_arr``. The entry of
    the global ``images`` list with the highest score is returned. Both
    globals are populated by ``embedding_img``.

    The query vector itself is not normalised, so the scores are not true
    cosine similarities. The ranking over images is unaffected by that
    constant scale.
    """
    query_vec = embedding_input(text_input)

    # Scale every image embedding (one per row) to unit length.
    row_norms = np.linalg.norm(image_arr, axis=1)
    unit_images = (image_arr.T / row_norms).T

    query_np = query_vec.cpu().detach().numpy()
    similarity = np.dot(query_np, unit_images.T)

    # Highest score first; only the single best match is used.
    ranking = np.argsort(-similarity[0])
    best_idx = ranking[0]
    return images[best_idx]
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the image index. embedding_img() already calls load_data(), so
    # the dataset is not loaded a second time before launching the UI.
    embedding_img()

    # Text box in, best-matching image out.
    iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
    iface.launch(inline=False)
|
|