import gradio as gr
import torch
import copy
import time
import requests
import io
import numpy as np
import re

from PIL import Image

from vilt.config import ex
from vilt.modules import ViLTransformerSS
from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer


@ex.automain
def main(_config):
    _config = copy.deepcopy(_config)

    # Enable only the MLM head; all other pretraining objectives are switched off.
    loss_names = {
        "itm": 0,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])
    _config.update(
        {
            "loss_names": loss_names,
        }
    )

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cpu"
    model.to(device)

    def infer(url, mp_text, hidx):
        try:
            # Download the image and apply the pixel-BERT style transform used by ViLT.
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            return False

        batch = {"text": [""], "image": [None]}
        tl = len(re.findall(r"\[MASK\]", mp_text))
        inferred_token = [mp_text]
        batch["image"][0] = img

        # Iteratively fill one [MASK] per step, always taking the most confident prediction.
        with torch.no_grad():
            for i in range(tl):
                batch["text"] = inferred_token
                encoded = tokenizer(inferred_token)
                batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
                encoded = encoded["input_ids"][0][1:-1]
                infer = model(batch)
                mlm_logits = model.mlm_score(infer["text_feats"])[0, 1:-1]
                mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
                # Only positions still holding the [MASK] token (id 103) may be replaced.
                mlm_values[torch.tensor(encoded) != 103] = 0
                select = mlm_values.argmax().item()
                encoded[select] = mlm_ids[select].item()
                inferred_token = [tokenizer.decode(encoded)]

        selected_token = ""
        encoded = tokenizer(inferred_token)

        if hidx > 0 and hidx < len(encoded["input_ids"][0][:-1]):
            # Compute a word-patch alignment heatmap for the selected token via IPOT.
            with torch.no_grad():
                batch["text"] = inferred_token
                batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
                infer = model(batch)
                txt_emb, img_emb = infer["text_feats"], infer["image_feats"]
                txt_mask, img_mask = (
                    infer["text_masks"].bool(),
                    infer["image_masks"].bool(),
                )
                # Exclude [CLS]/[SEP] text tokens and the image class token from the plan.
                for i, _len in enumerate(txt_mask.sum(dim=1)):
                    txt_mask[i, _len - 1] = False
                txt_mask[:, 0] = False
                img_mask[:, 0] = False
                txt_pad, img_pad = ~txt_mask, ~img_mask

                cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
                joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
                cost.masked_fill_(joint_pad, 0)

                txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(
                    dtype=cost.dtype
                )
                img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(
                    dtype=cost.dtype
                )
                T = ipot(
                    cost.detach(),
                    txt_len,
                    txt_pad,
                    img_len,
                    img_pad,
                    joint_pad,
                    0.1,
                    1000,
                    1,
                )
                plan = T[0]
                plan_single = plan * len(txt_emb)
                cost_ = plan_single.t()

                cost_ = cost_[hidx][1:].cpu()

                # Scatter the per-patch transport mass back onto the patch grid.
                patch_index, (H, W) = infer["patch_index"]
                heatmap = torch.zeros(H, W)
                for i, pidx in enumerate(patch_index[0]):
                    h, w = pidx[0].item(), pidx[1].item()
                    heatmap[h, w] = cost_[i]

                # Standardize, clip, and rescale the heatmap to [0, 1] for display.
                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap, 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

                # Use the heatmap as an alpha channel over the original image.
                _w, _h = image.size
                overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                    (_w, _h), resample=Image.NEAREST
                )
                image_rgba = image.copy()
                image_rgba.putalpha(overlay)
                image = image_rgba

                selected_token = tokenizer.convert_ids_to_tokens(
                    encoded["input_ids"][0][hidx]
                )

        return [np.array(image), inferred_token[0], selected_token]

    inputs = [
        gr.inputs.Textbox(
            label="URL of an image.",
            lines=5,
        ),
        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="description"),
        gr.outputs.Textbox(label="selected token"),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
                0,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                4,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                11,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                15,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                18,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a [MASK], a [MASK], a [MASK], and a [MASK].",
                0,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                5,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                8,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                11,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                15,
            ],
        ],
    )

    interface.launch(debug=True)
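
# Usage sketch. The checkpoint name and config flags below follow the dandelin/ViLT
# README and are assumptions; adjust paths and versions to your own setup. The old
# gr.inputs/gr.outputs API used above requires a 1.x/2.x Gradio release.
#
#   pip install gradio==1.6.4
#   python demo.py with num_gpus=0 load_path="weights/vilt_200k_mlm_itm.ckpt"
#
# Sacred's @ex.automain decorator invokes main() with the command-line config when
# the script is executed directly.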