# MUG_caption / app.py
# Author: tennant. Last commit 4d766ce: "add mask" (3.79 kB).
import torch
from model import MaskedAutoencoderViT, mae_vit_base_patch16
import numpy as np
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download
# BERT tokenizer used to encode caption prefixes and decode predicted tokens.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Fetch the MUG checkpoint (mapped onto the CPU) and re-key it for the bare
# MAE model by slicing len('image_encoder.model.') characters off every key.
# NOTE(review): the slice is unconditional, so this assumes every key in the
# checkpoint carries that prefix -- confirm against the checkpoint contents.
ckpt = torch.load(
    hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'),
    map_location='cpu',
)
_PREFIX_LEN = len('image_encoder.model.')
new_dict = OrderedDict(
    (name[_PREFIX_LEN:], weight) for name, weight in ckpt.items()
)

# Build the model, load the re-keyed weights (load_state_dict is strict by
# default), move it to the GPU when one is present, and switch to eval mode.
model = mae_vit_base_patch16(uni_dim=768, less_u=True)
model.load_state_dict(new_dict)
if torch.cuda.is_available():
    model.cuda()
model.eval()
@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
    """Randomly mask *x*, reconstruct it with the MAE, and return display images.

    Args:
        x: normalized image batch in NCHW layout (``model.patchify`` is applied
            to it directly).
        model: the MAE; must provide ``patchify``, ``unpatchify``,
            ``forward_encoder``, ``forward_decoder`` and ``patch_embed``.
        mask_ratio: fraction of patches hidden from the encoder, forwarded to
            ``model.forward_encoder``. Defaults to 0.75 (the value that used
            to be hard-coded).

    Returns:
        ``(masked, pasted, recon, latent)``. The first three are NHWC tensors:
        the input with masked pixels set to 0, the input with masked pixels
        replaced by the reconstruction, and the full reconstruction.
        ``latent`` is the encoder output (used later for captioning).
    """
    # Per-patch statistics of the input. The decoder output is un-normalized
    # with them below (presumably the model predicts per-patch normalized
    # pixels -- TODO confirm against training code).
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6) ** .5 + mean
    y = torch.einsum('nchw->nhwc', model.unpatchify(y))

    # Expand the per-patch mask to per-pixel values: 1 is removing, 0 is keeping.
    patch_size = model.patch_embed.patch_size[0]
    mask = mask.unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)  # (N, H*W, p*p*3)
    mask = torch.einsum('nchw->nhwc', model.unpatchify(mask))

    x = torch.einsum('nchw->nhwc', x)
    return x * (1 - mask), x * (1 - mask) + y * mask, y, latent
@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    """Greedily predict the token that follows *prefix* for a single image.

    Args:
        latent: encoder output for one image (batch dimension must be 1).
        model: model exposing ``embed_text``, ``multimodal_layers`` (an
            iterable of pairs of cross-attention callables) and ``to_logits``.
        tokenizer: Hugging Face-style tokenizer used to encode *prefix* and
            decode the prediction.
        prefix: caption text generated so far.

    Returns:
        The decoded string of the highest-scoring token at the last prefix
        position. Padding, unknown, [CLS] and [MASK] tokens are never chosen.
        [SEP] is allowed so callers can detect the end of the caption.
    """
    assert latent.shape[0] == 1, 'can only caption one image at a time'
    # Tokenize and drop the trailing [SEP] so the model continues the prefix.
    ids = torch.tensor(tokenizer([prefix, ])['input_ids'])[:, :-1]
    # Put the tokens wherever the image latent lives instead of re-probing CUDA.
    x_l = model.embed_text(ids.to(latent.device))
    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    # Only the last position's logits matter for the next token.
    logits = model.to_logits(x_l)[0, -1]
    banned = [
        tok_id
        for tok_id in (
            tokenizer.pad_token_id,
            tokenizer.unk_token_id,
            tokenizer.cls_token_id,
            tokenizer.mask_token_id,
        )
        if tok_id is not None
    ]
    if banned:
        # -inf (rather than a finite floor such as -100) guarantees these
        # tokens can never win, however low the real logits are.
        index = torch.tensor(banned, device=logits.device)
        logits = logits.index_fill(0, index, float('-inf'))
    next_id = int(logits.argmax())
    return tokenizer.decode(next_id)
def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    """Greedily generate a caption for one image, starting from *prefix*.

    Repeatedly calls ``caption_next_word``, feeding back the text so far.

    Args:
        max_len: word budget including the prefix's words; at most
            ``max_len - len(prefix.split())`` tokens are generated.
        latent, model, tokenizer: forwarded to ``caption_next_word``.
        prefix: text the caption starts with.

    Returns:
        The prefix plus the generated words joined by spaces. Generation stops
        early once '[SEP]' is produced; '[SEP]' is kept at the end.
    """
    words = prefix.split()
    # Bound the number of model calls explicitly. Merging WordPiece pieces
    # below does not grow ``words``, so a length-only loop condition could
    # spin forever.
    for _ in range(max(max_len - len(words), 0)):
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        if next_word.startswith('##') and words:
            # WordPiece continuation (e.g. '##s'): glue it onto the previous
            # word. As a separate word, re-tokenizing the prefix would turn it
            # into the literal characters '#', '#', 's'.
            words[-1] += next_word[2:]
        else:
            words.append(next_word)
        if next_word == '[SEP]':
            break
    return ' '.join(words)
def gr_caption(x):
    """Gradio callback: mask and reconstruct one image with the MAE, then caption it.

    Args:
        x: the uploaded image. Assumed to be an H x W x 3 RGB array with
            0-255 values -- TODO confirm against the gr.Image configuration.

    Returns:
        ``(masked, masked_recon, recon, caption)``. The three images are
        H x W x 3 float arrays clipped to [0, 1]; ``caption`` is a string.
    """
    mean_rgb = np.array([0.485, 0.456, 0.406])
    std_rgb = np.array([0.229, 0.224, 0.225])

    def to_display(tensor):
        # Drop the batch axis, undo the ImageNet normalization, clamp to [0, 1].
        arr = tensor.squeeze(0).cpu().detach().numpy()
        return np.clip(arr * std_rgb + mean_rgb, a_min=0., a_max=1.)

    normalized = (np.array(x) / 255. - mean_rgb) / std_rgb
    batch = torch.einsum('nhwc->nchw', torch.tensor(normalized).float().unsqueeze(0))
    if torch.cuda.is_available():
        batch = batch.cuda()

    masked, masked_recon, recon, latent = visual_recon(batch, model)
    text = caption(20, latent, model, tokenizer)
    shown = [to_display(img) for img in (masked, masked_recon, recon)]
    return shown[0], shown[1], shown[2], text
import gradio as gr
# Web UI: one 224x224 input image. The outputs follow gr_caption's return
# order: masked input, input with masked patches filled in, full
# reconstruction, and the caption text.
# NOTE(review): `shape=` is a Gradio 3.x gr.Image argument; confirm the
# pinned Gradio version supports it.
demo = gr.Interface(
    fn=gr_caption,
    inputs=[gr.Image(shape=(224, 224))],
    outputs=[gr.Image(shape=(224, 224)) for _ in range(3)] + ['text'],
)
demo.launch()