import io
import re

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import ViltForMaskedLM, ViltProcessor

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)


class MinMaxResize:
    def __init__(self, shorter=800, longer=1333):
        self.min = shorter
        self.max = longer

    def __call__(self, x):
        w, h = x.size
        scale = self.min / min(w, h)
        if h < w:
            newh, neww = self.min, scale * w
        else:
            newh, neww = scale * h, self.min

        if max(newh, neww) > self.max:
            scale = self.max / max(newh, neww)
            newh = newh * scale
            neww = neww * scale

        newh, neww = int(newh + 0.5), int(neww + 0.5)
        # make both sides divisible by the 32-pixel patch size
        newh, neww = newh // 32 * 32, neww // 32 * 32
        return x.resize((neww, newh), resample=Image.Resampling.BICUBIC)


def pixelbert_transform(size=800):
    longer = int((1333 / 800) * size)
    return transforms.Compose(
        [
            MinMaxResize(shorter=size, longer=longer),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )


def cost_matrix_cosine(x, y, eps=1e-5):
    """Compute the cosine distance between every pair of x, y (batched).

    [B, L_x, D], [B, L_y, D] -> [B, L_x, L_y]
    """
    assert x.dim() == y.dim()
    assert x.size(0) == y.size(0)
    assert x.size(2) == y.size(2)
    x_norm = F.normalize(x, p=2, dim=-1, eps=eps)
    y_norm = F.normalize(y, p=2, dim=-1, eps=eps)
    cosine_sim = x_norm.matmul(y_norm.transpose(1, 2))
    cosine_dist = 1 - cosine_sim
    return cosine_dist


@torch.no_grad()
def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, beta, iteration, k):
    """Inexact proximal point optimal transport.

    Shapes: C [B, M, N], x_len [B], x_pad [B, M], y_len [B], y_pad [B, N],
    joint_pad [B, M, N].
    """
    b, m, n = C.size()
    sigma = torch.ones(b, m, dtype=C.dtype, device=C.device) / x_len.unsqueeze(1)
    T = torch.ones(b, n, m, dtype=C.dtype, device=C.device)
    A = torch.exp(-C.transpose(1, 2) / beta)

    # mask padded positions
    sigma.masked_fill_(x_pad, 0)
    joint_pad = joint_pad.transpose(1, 2)
    T.masked_fill_(joint_pad, 0)
    A.masked_fill_(joint_pad, 0)

    # broadcastable lengths
    x_len = x_len.unsqueeze(1).unsqueeze(2)
    y_len = y_len.unsqueeze(1).unsqueeze(2)

    # mask to zero out padding in delta and sigma
    x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1)
    y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1)

    for _ in range(iteration):
        Q = A * T  # bs * n * m
        sigma = sigma.view(b, m, 1)
        for _ in range(k):
            delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask)
            sigma = 1 / (x_len * delta.matmul(Q) + x_mask)
        T = delta.view(b, n, 1) * Q * sigma
    T.masked_fill_(joint_pad, 0)
    return T
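

# Illustrative sketch (not part of the original demo): how cost_matrix_cosine
# and ipot fit together on random, unpadded embeddings. The shapes and the
# beta/iteration/k values below are assumptions chosen for a quick check; the
# helper is never called by the app.
def _ipot_sanity_check(batch=1, n_txt=4, n_img=6, dim=8):
    txt = torch.randn(batch, n_txt, dim)
    img = torch.randn(batch, n_img, dim)
    txt_pad = torch.zeros(batch, n_txt, dtype=torch.bool)  # no padding
    img_pad = torch.zeros(batch, n_img, dtype=torch.bool)
    joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
    cost = cost_matrix_cosine(txt, img)
    txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1)).to(dtype=cost.dtype)
    img_len = (img_pad.size(1) - img_pad.sum(dim=1)).to(dtype=cost.dtype)
    plan = ipot(cost, txt_len, txt_pad, img_len, img_pad, joint_pad, 0.1, 50, 1)
    # plan has shape (batch, n_img, n_txt): transported mass between each
    # image position and each text position
    return plan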


def get_model_embedding_and_mask(model, input_ids, pixel_values):
    """Return text/image embeddings, masks, and the patch index from ViLT's embedding layer."""
    input_shape = input_ids.size()
    text_batch_size, seq_length = input_shape
    device = input_ids.device
    attention_mask = torch.ones((text_batch_size, seq_length), device=device)
    image_batch_size = pixel_values.shape[0]
    image_token_type_idx = 1

    if image_batch_size != text_batch_size:
        raise ValueError("The text inputs and image inputs need to have the same batch size")

    pixel_mask = torch.ones(
        (image_batch_size, model.vilt.config.image_size, model.vilt.config.image_size),
        device=device,
    )

    text_embeds = model.vilt.embeddings.text_embeddings(
        input_ids=input_ids, token_type_ids=None, inputs_embeds=None
    )
    image_embeds, image_masks, patch_index = model.vilt.embeddings.visual_embed(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        max_image_length=model.vilt.config.max_image_length,
    )

    # add modality (token type) embeddings: 0 for text, 1 for image
    text_embeds = text_embeds + model.vilt.embeddings.token_type_embeddings(
        torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
    )
    image_embeds = image_embeds + model.vilt.embeddings.token_type_embeddings(
        torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
    )

    return text_embeds, image_embeds, attention_mask, image_masks, patch_index


def infer(url, mp_text, hidx):
    try:
        res = requests.get(url)
        image = Image.open(io.BytesIO(res.content)).convert("RGB")
        img = pixelbert_transform(size=500)(image)
        img = img.unsqueeze(0).to(device)
    except Exception:
        return False

    tl = len(re.findall(r"\[MASK\]", mp_text))
    inferred_token = [mp_text]
    encoding = processor(image, mp_text, return_tensors="pt")

    # iteratively fill the [MASK] tokens, most confident prediction first
    with torch.no_grad():
        for i in range(tl):
            encoded = processor.tokenizer(inferred_token)
            input_ids = torch.tensor(encoded.input_ids)
            encoded = encoded["input_ids"][0][1:-1]
            outputs = model(
                input_ids=input_ids.to(device),
                pixel_values=encoding.pixel_values.to(device),
            )
            mlm_logits = outputs.logits[0].cpu()  # shape (seq_len, vocab_size)
            # only take into account text features (minus CLS and SEP token)
            mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
            mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
            # only keep predictions for the remaining [MASK] positions
            mlm_values[torch.tensor(encoded) != processor.tokenizer.mask_token_id] = 0
            select = mlm_values.argmax().item()
            encoded[select] = mlm_ids[select].item()
            inferred_token = [processor.decode(encoded)]

    encoded = processor.tokenizer(inferred_token)
    output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)

    selected_token = ''
    result = Image.open('no_heatmap.jpg')

    if hidx > 0 and hidx < len(encoded["input_ids"][0][:-1]):
        input_ids = torch.tensor(encoded.input_ids).to(device)
        pixel_values = encoding.pixel_values.to(device)
        outputs = model(
            input_ids=input_ids, pixel_values=pixel_values, output_hidden_states=True
        )
        txt_emb, img_emb, text_masks, image_masks, patch_index = get_model_embedding_and_mask(
            model, input_ids=input_ids, pixel_values=pixel_values
        )

        embedding_output = torch.cat([txt_emb, img_emb], dim=1)
        attention_mask = torch.cat([text_masks, image_masks], dim=1)
        extended_attention_mask = model.vilt.get_extended_attention_mask(
            attention_mask, input_ids.size(), device=device
        )

        encoder_outputs = model.vilt.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        x = encoder_outputs.hidden_states[-1]
        x = model.vilt.layernorm(x)

        txt_emb, img_emb = (
            x[:, : txt_emb.shape[1]],
            x[:, txt_emb.shape[1]:],
        )
        txt_mask, img_mask = (
            text_masks.bool(),
            image_masks.bool(),
        )
        # ignore special positions: text [CLS]/[SEP] and the first image position
        for i, _len in enumerate(txt_mask.sum(dim=1)):
            txt_mask[i, _len - 1] = False
        txt_mask[:, 0] = False
        img_mask[:, 0] = False
        txt_pad, img_pad = ~txt_mask, ~img_mask

        cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
        joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
        cost.masked_fill_(joint_pad, 0)

        txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(dtype=cost.dtype)
        img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(dtype=cost.dtype)
        T = ipot(
            cost.detach(), txt_len, txt_pad, img_len, img_pad, joint_pad, 0.1, 1000, 1
        )

        plan = T[0]
        plan_single = plan * len(txt_emb)
        cost_ = plan_single.t()
        cost_ = cost_[hidx][1:].cpu()

        # scatter the transport mass of the selected token back onto the
        # patch grid, then normalize it for display
        patch_index, (H, W) = patch_index
        heatmap = torch.zeros(H, W)
        for i, pidx in enumerate(patch_index[0]):
            h, w = pidx[0].item(), pidx[1].item()
            heatmap[h, w] = cost_[i]

        heatmap = (heatmap - heatmap.mean()) / heatmap.std()
        heatmap = np.clip(heatmap.numpy(), 1.0, 3.0)
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
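
        # The normalized plan is used below as an alpha channel: it is
        # upsampled to the original image size with nearest-neighbour
        # resampling, so brighter regions mark the patches that the transport
        # plan ties most strongly to the selected token.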
        _w, _h = image.size
        overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
            (_w, _h), resample=Image.Resampling.NEAREST
        )
        image_rgba = image.copy()
        image_rgba.putalpha(overlay)
        result = image_rgba

        selected_token = processor.tokenizer.convert_ids_to_tokens(
            encoded["input_ids"][0][hidx]
        )

    return [np.array(image), output, selected_token, result]


title = "What's in the picture?"

description = """
Can't find the words to describe an image? The pre-trained ViLT model will help you.
Give the URL of an image and a caption with [MASK] tokens to be filled, or play with the given examples!
You can even see where the model focused its attention for a given word: just choose the index of the selected word with the slider.
"""

inputs_interface = [
    gr.inputs.Textbox(
        label="URL of an image.",
        lines=5,
    ),
    gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
    gr.inputs.Slider(
        minimum=0,
        maximum=38,
        step=1,
        label="Index of token for heatmap visualization (ignored if zero)",
    ),
]

outputs_interface = [
    gr.outputs.Image(label="Image"),
    gr.outputs.Textbox(label="description"),
    gr.outputs.Textbox(label="selected token"),
    gr.outputs.Image(label="Heatmap"),
]

interface = gr.Interface(
    fn=infer,
    inputs=inputs_interface,
    outputs=outputs_interface,
    title=title,
    description=description,
    examples=[
        [
            "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
            "a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
            0,
        ],
        [
            "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT5W71UTcSBm3r5l9NzBemglq983bYvKOHRkw&usqp=CAU",
            "An [MASK] with the [MASK] in the [MASK].",
            5,
        ],
        [
            "https://www.referenseo.com/wp-content/uploads/2019/03/image-attractive-960x540.jpg",
            "An [MASK] is flying with a [MASK] over a [MASK].",
            2,
        ],
    ],
)

interface.launch(server_name="0.0.0.0", server_port=8888)
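
# Illustrative usage sketch (assumption, not part of the original demo):
# infer() can also be called directly, bypassing the UI. It returns the input
# image as an array, the completed caption, the selected token, and the
# heatmap overlay.
#
#   image_arr, caption, token, heatmap_img = infer(
#       "https://www.referenseo.com/wp-content/uploads/2019/03/image-attractive-960x540.jpg",
#       "An [MASK] is flying with a [MASK] over a [MASK].",
#       2,
#   )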