Spaces:
Runtime error
Runtime error
File size: 4,023 Bytes
9ad81d2 b14c338 9ad81d2 35d1065 9ad81d2 35d1065 9ad81d2 450eb04 9ad81d2 450eb04 9ad81d2 35d1065 9ad81d2 35d1065 9ad81d2 1ceff34 9ad81d2 1ceff34 9ad81d2 4d766ce 35d1065 1ceff34 35d1065 9ad81d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch
from model import MaskedAutoencoderViT, mae_vit_base_patch16
import numpy as np
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download
# Tokenizer used to turn caption prefixes into token ids.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Download the checkpoint from the Hub and deserialize it onto the CPU.
ckpt = torch.load(
    hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'),
    map_location='cpu',
)

# Re-key the weights by cutting len('image_encoder.model.') characters off the
# front of every key.
# NOTE(review): the slice is unconditional -- it assumes every checkpoint key
# starts with that prefix; confirm against the checkpoint contents.
new_dict = OrderedDict(
    (key[len('image_encoder.model.'):], weight)
    for key, weight in ckpt.items()
)

model = mae_vit_base_patch16(uni_dim=768, less_u=True)
model.load_state_dict(new_dict)

# Inference only; use the GPU when one is present.
model.eval()
if torch.cuda.is_available():
    model.cuda()
@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
    """Mask part of ``x``, reconstruct it with ``model`` and build display images.

    Args:
        x: image batch in NCHW layout (it is converted with an
            ``'nchw->nhwc'`` einsum before being returned).
        model: object providing ``patchify``, ``unpatchify``,
            ``forward_encoder``, ``forward_decoder`` and
            ``patch_embed.patch_size``.
        mask_ratio: forwarded unchanged to ``model.forward_encoder``.

    Returns:
        A 4-tuple ``(masked, blended, recon, latent)``:
        the input with masked pixels zeroed, the input with masked pixels
        replaced by the reconstruction, the full reconstruction (all three in
        NHWC layout), and the encoder output ``latent``.
    """

    def to_nhwc(t):
        # Channel-first -> channel-last.
        return torch.einsum('nchw->nhwc', t)

    # Per-patch statistics of the input, used to de-normalise the prediction.
    patches = model.patchify(x)
    patch_mean = patches.mean(dim=-1, keepdim=True)
    patch_var = patches.var(dim=-1, keepdim=True)

    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    pred, _ = model.forward_decoder(latent, ids_restore)

    # Undo the per-patch normalisation, then fold patches back into an image.
    pred = pred * (patch_var + 1.e-6)**.5 + patch_mean
    recon = to_nhwc(model.unpatchify(pred))

    # Expand the per-patch mask to one value per pixel value: (N, H*W, p*p*3).
    values_per_patch = model.patch_embed.patch_size[0]**2 * 3
    pixel_mask = mask.unsqueeze(-1).repeat(1, 1, values_per_patch)
    pixel_mask = to_nhwc(model.unpatchify(pixel_mask))  # 1 is removing, 0 is keeping

    original = to_nhwc(x)
    visible = original * (1 - pixel_mask)
    blended = visible + recon * pixel_mask
    return visible, blended, recon, latent
@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    """Predict the token that follows ``prefix`` for a single encoded image.

    Args:
        latent: encoder output for exactly one image (``latent.shape[0]``
            must be 1).
        model: object providing ``embed_text``, ``multimodal_layers`` (an
            iterable of layer pairs) and ``to_logits``.
        tokenizer: callable returning a mapping with ``'input_ids'`` and
            offering ``decode``.
        prefix: text generated so far.

    Returns:
        The decoded string of the highest-scoring next token.

    Raises:
        AssertionError: if ``latent`` holds more than one image.
    """
    assert latent.shape[0] == 1, 'can only caption one image at a time'

    # Drop the last id the tokenizer emits (presumably the closing [SEP]) so
    # the sequence ends on the last real prefix token.
    token_ids = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1]
    # Follow the device of the image features instead of guessing from global
    # CUDA availability, so text and image tensors always live together.
    token_ids = token_ids.to(latent.device)

    x_l = model.embed_text(token_ids)
    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    pred = model.to_logits(x_l)
    # Push down ids that must never be generated. In the bert-base-uncased
    # vocabulary these are [PAD]=0, [UNK]=100, [CLS]=101 and [MASK]=103;
    # [SEP] (102) stays available so generation can signal the end.
    for banned_id in (103, 101, 100, 0):
        pred[:, :, banned_id] = -100

    # Greedy choice at the last position of the only batch element.
    next_id = pred.argmax(dim=-1)[0, -1]
    return tokenizer.decode(next_id)
def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    """Greedily extend ``prefix`` one token at a time into a caption.

    Args:
        max_len: upper bound on the number of whitespace-separated words in
            the result (the words of ``prefix`` count towards it).
        latent: encoded image, passed through to ``caption_next_word``.
        model: captioning model, passed through to ``caption_next_word``.
        tokenizer: tokenizer, passed through to ``caption_next_word``.
        prefix: initial text of the caption.

    Returns:
        The caption as a single space-joined string. Generation stops when
        ``max_len`` words are reached or the model emits ``'[SEP]'``; the
        ``'[SEP]'`` marker itself is not included in the returned text.
    """
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        if next_word == '[SEP]':
            # End-of-sequence marker: stop here and keep it out of the
            # user-facing caption.
            break
        words.append(next_word)
    return ' '.join(words)
def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a'):
    """Gradio handler: reconstruct a masked image and caption it.

    Args:
        x: input image; assumed to be an H x W x 3 array with values in
            0..255 as delivered by the Gradio image input -- TODO confirm.
        mask_ratio: passed to ``visual_recon``.
        max_len: maximum caption length in words, passed to ``caption``.
        prefix: caption prefix, passed to ``caption``.

    Returns:
        ``(image, text)`` where ``image`` is the masked input, the blended
        reconstruction and the full reconstruction concatenated along axis 1,
        with values clipped to [0, 1], and ``text`` is the generated caption.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    def unnorm_pix(img):
        # Drop the batch axis, move to NumPy and invert the normalisation.
        pixels = img.squeeze(0).cpu().detach().numpy()
        pixels = pixels * imagenet_std + imagenet_mean
        return np.clip(pixels, a_min=0., a_max=1.)

    # Scale to [0, 1], normalise per channel, then build a 1-image NCHW batch.
    normalised = (np.array(x) / 255. - imagenet_mean) / imagenet_std
    batch = torch.tensor(normalised).float().unsqueeze(0)
    batch = torch.einsum('nhwc->nchw', batch)
    if torch.cuda.is_available():
        batch = batch.cuda()

    masked, masked_recon, recon, latent = visual_recon(batch, model, mask_ratio=mask_ratio)
    caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)

    panels = [unnorm_pix(t) for t in (masked, masked_recon, recon)]
    return np.concatenate(panels, axis=1), caption_from_model
import gradio as gr
# Web UI: the four inputs map positionally onto gr_caption's parameters
# (image, mask ratio, maximum caption length, caption prefix); the outputs are
# the side-by-side reconstruction panel and the generated caption.
demo = gr.Interface(
    fn=gr_caption,
    inputs=[
        gr.Image(shape=(224, 224)),
        'number',
        'number',
        'text',
    ],
    outputs=[
        gr.Image(shape=(224, 224 * 3)),
        'text',
    ],
    examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],
)

# Start the Gradio server.
demo.launch()
|