|
import gradio as gr |
|
import torch |
|
import copy |
|
import time |
|
import requests |
|
import io |
|
import numpy as np |
|
import re |
|
from einops import rearrange |
|
|
|
import ipdb |
|
|
|
from PIL import Image |
|
|
|
from vilt.config import ex |
|
from vilt.modules import ViLTransformerSS |
|
|
|
from vilt.modules.objectives import cost_matrix_cosine, ipot |
|
from vilt.transforms import pixelbert_transform |
|
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer |
|
|
|
|
|
def _load_rgb(path):
    """Open the image at *path* and return an RGB copy, closing the file handle."""
    with Image.open(path) as raw:
        return raw.convert("RGB")


def _normalize_heatmap(heatmap):
    """Standardize *heatmap*, clamp it to [1, 3] and rescale it to [0, 1].

    An all-zero map is returned when the input has no usable spread (zero or
    non-finite standard deviation, or every value collapsing to the same
    number after clamping). Without these guards the divisions below yield
    NaNs, which then get cast to ``uint8`` for the alpha mask.
    """
    std = heatmap.std()
    if not torch.isfinite(std) or std.item() == 0:
        return torch.zeros_like(heatmap)

    clamped = ((heatmap - heatmap.mean()) / std).clamp(1.0, 3.0)
    low, high = clamped.min(), clamped.max()
    if high.item() <= low.item():
        return torch.zeros_like(heatmap)
    return (clamped - low) / (high - low)


@torch.no_grad()
def _token_heatmap_overlay(output, token_idx, image):
    """Return a copy of *image* whose alpha channel highlights one text token.

    Args:
        output: Dict returned by the model's forward pass; the keys read here
            are ``text_feats``, ``image_feats``, ``text_masks``,
            ``image_masks`` and ``patch_index``.
        token_idx: Integer position of the token in the tokenized caption.
        image: PIL image the forward pass was computed from. It is not
            modified; the overlay is applied to a copy.

    Returns:
        A PIL image with the per-patch heatmap installed as its alpha channel.
    """
    txt_emb, img_emb = output["text_feats"], output["image_feats"]
    txt_mask, img_mask = (
        output["text_masks"].bool(),
        output["image_masks"].bool(),
    )
    # Drop the last attended text position and the first position of both
    # sequences from the alignment (presumably [SEP] / [CLS] — TODO confirm).
    for i, _len in enumerate(txt_mask.sum(dim=1)):
        txt_mask[i, _len - 1] = False
    txt_mask[:, 0] = False
    img_mask[:, 0] = False
    txt_pad, img_pad = ~txt_mask, ~img_mask

    cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
    joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
    cost.masked_fill_(joint_pad, 0)

    txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(
        dtype=cost.dtype
    )
    img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(
        dtype=cost.dtype
    )
    # NOTE(review): the trailing positional arguments (0.1, 1000, 1) are
    # passed through unchanged; confirm their meaning against
    # vilt.modules.objectives.ipot.
    T = ipot(
        cost.detach(),
        txt_len,
        txt_pad,
        img_len,
        img_pad,
        joint_pad,
        0.1,
        1000,
        1,
    )

    plan = T[0]
    plan_single = plan * len(txt_emb)
    # Row `token_idx` of the transposed plan, minus its first entry (the
    # image position that was masked out above).
    token_plan = plan_single.t()[token_idx][1:].cpu()

    patch_index, (H, W) = output["patch_index"]
    heatmap = torch.zeros(H, W)
    for i, pidx in enumerate(patch_index[0]):
        h, w = pidx[0].item(), pidx[1].item()
        heatmap[h, w] = token_plan[i]

    heatmap = _normalize_heatmap(heatmap)

    _w, _h = image.size
    alpha = (heatmap * 255).detach().numpy().astype(np.uint8)
    overlay = Image.fromarray(alpha, "L").resize((_w, _h), resample=Image.NEAREST)
    image_rgba = image.copy()
    image_rgba.putalpha(overlay)
    return image_rgba


@ex.automain
def main(_config):
    """Launch a Gradio demo that retrieves the best-matching image for a caption.

    Builds a ``ViLTransformerSS`` model from ``_config`` (with the loss
    weights overridden below), loads a fixed set of database images once, and
    serves an interface on ``localhost:8888`` whose callback scores the
    caption against every database image.

    Args:
        _config: Experiment configuration supplied through ``ex``; the keys
            read directly here are ``"tokenizer"`` and ``"num_gpus"``. A deep
            copy is modified, the caller's object is left alone.
    """
    _config = copy.deepcopy(_config)

    loss_names = {
        "itm": 1,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 1,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

    _config.update(
        {
            "loss_names": loss_names,
        }
    )

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    # NOTE(review): machine-specific absolute path; the demo only runs where
    # this directory exists with files 1.jpg .. 9.jpg.
    lst_imgs = [f"C:\\Users\\alimh\\PycharmProjects\\ViLT\\assets\\database\\{i}.jpg" for i in range(1,10)]

    # The database is fixed for the lifetime of the server, so decode and
    # preprocess every image once here instead of on every request.
    imgs = [_load_rgb(pth) for pth in lst_imgs]
    transform = pixelbert_transform(size=384)
    img_tensors = [transform(image).unsqueeze(0).to(device) for image in imgs]

    def infer(mp_text, hidx=0):
        """Score ``mp_text`` against every database image and return the best.

        Args:
            mp_text: Caption text entered in the UI.
            hidx: Token index for the heatmap. When it is strictly between 0
                and the index of the last token, the returned image carries a
                heatmap for that token in its alpha channel; otherwise the
                plain image is returned.

        Returns:
            ``[image, hidx]`` where ``image`` is a NumPy array (no heatmap)
            or an RGBA PIL image (heatmap requested).
        """
        # The caption is identical for every image, so tokenize it once.
        encoded = tokenizer([mp_text])
        token_ids = encoded["input_ids"][0]
        text_ids = torch.tensor(encoded["input_ids"]).to(device)
        text_masks = torch.tensor(encoded["attention_mask"]).to(device)

        # The slider may deliver a float; indexing below needs an int.
        token_idx = int(hidx)
        want_heatmap = 0 < token_idx < len(token_ids) - 1

        scores = []
        forward_outputs = []
        with torch.no_grad():
            for img in img_tensors:
                batch = {
                    "text": [mp_text],
                    "image": [img],
                    "text_ids": text_ids,
                    "text_labels": text_ids.clone(),
                    "text_masks": text_masks,
                }
                started = time.time()
                output = model(batch)
                print("time ", round(time.time() - started, 2))

                scores.append(model.rank_output(output["cls_feats"]).item())
                if want_heatmap:
                    # Keep the pass so the heatmap can reuse it rather than
                    # running the model a second time on the winner.
                    forward_outputs.append(output)

        print(scores)
        img_idx = int(np.argmax(scores))
        print(img_idx + 1)

        image = imgs[img_idx]
        selected_image = np.asarray(image)
        print(selected_image.shape)

        if want_heatmap:
            selected_image = _token_heatmap_overlay(
                forward_outputs[img_idx], token_idx, image
            )

        # NOTE(review): the second output is labelled "matching index " in the
        # UI but receives the slider value, not `img_idx + 1`. Left unchanged;
        # confirm which one is intended.
        return [selected_image, hidx]

    inputs = [
        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="matching index "),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="localhost",
        server_port=8888,
    )

    interface.launch(debug=True, share=False)