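# Gradio demo for the MUG checkpoint (tennant/MUG on the Hugging Face Hub):
# a ViT-B/16 masked autoencoder that both reconstructs a partially masked image
# and generates a caption for it through a BERT-tokenized text head.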
import torch
from model import mae_vit_base_patch16
import numpy as np
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download

# BERT tokenizer used by the captioning head.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Download the MUG checkpoint (ViT-B/16 MAE + captioning head, trained on CC3M) from the Hub.
ckpt = torch.load(hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'), map_location='cpu')

# The checkpoint stores the weights under an 'image_encoder.model.' prefix; strip it
# so the keys match the standalone MAE ViT defined in model.py.
new_dict = OrderedDict()
for k, v in ckpt.items():
    new_dict[k[len('image_encoder.model.'):]] = v

# Build the ViT-B/16 masked autoencoder with the multimodal captioning branch;
# uni_dim and less_u must match the configuration the checkpoint was trained with.
model = mae_vit_base_patch16(uni_dim=768, less_u=True)

model.load_state_dict(new_dict)
if torch.cuda.is_available():
    model.cuda()
model.eval()

@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
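    """Randomly mask the image, reconstruct it with the MAE decoder, and return
    (masked input, masked input pasted with reconstruction, reconstruction, encoder latent);
    the three images are in NHWC layout."""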
    # Per-patch pixel statistics of the input; the decoder predicts normalized
    # patches (MAE's norm_pix_loss), so mean/var are needed to undo that below.
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

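    # Encode only the visible patches, decode all patches, then map the normalized
    # prediction back to pixel space using the per-patch statistics.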
    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6)**.5 + mean
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y)
    
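    # Expand the per-patch binary mask to full image resolution for visualization.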
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask)
    
    x = torch.einsum('nchw->nhwc', x)
    
    return x * (1 - mask), x * (1 - mask) + y * mask, y, latent

@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
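    """Predict the single most likely next word for `prefix`, conditioned on the image latent."""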
    assert latent.shape[0] == 1, 'can only caption one image at a time'
    
    # Tokenize the prefix and drop the trailing [SEP] so the model predicts the next token.
    x_l = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1]
    seq = x_l.shape[1]
    if torch.cuda.is_available():
        x_l = x_l.cuda()

    # NOTE: this attention mask is built but never passed to the layers below.
    cls_mask = rearrange(x_l != 0, 'b j -> b 1 j')
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

    x_l = model.embed_text(x_l)

    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    pred = model.to_logits(x_l)
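    # Suppress BERT special tokens so they are never chosen as the next word:
    # 0 = [PAD], 100 = [UNK], 101 = [CLS], 103 = [MASK] in the bert-base-uncased vocab.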
    pred[:, :, 103] = -100
    pred[:, :, 101] = -100
    pred[:, :, 100] = -100
    pred[:, :, 0] = -100
    next_word = pred.argmax(dim=-1)[0, -1]
    next_word = tokenizer.decode(next_word)
    
    return next_word

def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
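    """Greedily extend the prefix one word at a time until [SEP] is produced or max_len words are reached."""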
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        if next_word == '[SEP]':
            # End-of-sequence token: stop without adding it to the caption.
            break
        words.append(next_word)
    return ' '.join(words)


def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a'):
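    """Gradio callback: normalize the input image, run masked reconstruction and captioning,
    and return (masked | masked + reconstruction | reconstruction) side by side plus the caption."""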
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    x = np.array(x) / 255.
    x = x - imagenet_mean
    x = x / imagenet_std

    x = torch.tensor(x).float()
    x = x.unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x)
    if torch.cuda.is_available():
        x = x.cuda()
        
    def unnorm_pix(img):
        img = img.squeeze(0).cpu().detach().numpy()
        img = img * imagenet_std + imagenet_mean
        return np.clip(img, a_min=0., a_max=1.)

    masked, masked_recon, recon, latent = visual_recon(x, model, mask_ratio=mask_ratio)
    caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)
    
    masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
    return_img = np.concatenate([masked, masked_recon, recon], axis=1)
    
    return return_img, caption_from_model

import gradio as gr

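# The inputs map one-to-one onto gr_caption's arguments: image, mask_ratio, max_len, prefix.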
demo = gr.Interface(gr_caption, 
                    inputs=[gr.Image(shape=(224, 224)),
                            'number',
                            'number',
                            'text'], 
                    outputs=[gr.Image(shape=(224, 224 * 3)), 
                             'text'],
                    examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],)
demo.launch()