File size: 2,112 Bytes
14bb247
581506e
 
3d04a5c
 
14bb247
581506e
 
14bb247
 
581506e
 
 
 
859c3ef
581506e
14bb247
 
7440015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14bb247
7440015
 
 
 
14bb247
 
 
 
859c3ef
 
 
14bb247
 
581506e
859c3ef
14bb247
 
 
 
 
859c3ef
581506e
14bb247
 
 
 
 
 
893c03b
581506e
3d04a5c
859c3ef
aa94b73
 
581506e
aa94b73
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from transformers import ViTMAEForPreTraining, ViTImageProcessor
import numpy as np
import torch
import gradio as gr

# Load the image processor and the MAE model fine-tuned on Formula 1 imagery
# from the Hugging Face Hub (weights are downloaded on first run).
image_processor = ViTImageProcessor.from_pretrained('andrewbo29/vit-mae-base-formula1')
model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')

# Per-channel normalization statistics used by the processor; kept as numpy
# arrays so prep_image() can de-normalize model outputs back to pixel values.
imagenet_mean = np.array(image_processor.image_mean)
imagenet_std = np.array(image_processor.image_std)


def prep_image(image):
    """Undo ImageNet-style normalization and return a displayable int array.

    `image` is a channels-last tensor (H, W, C) in normalized space; the
    result is clamped to [0, 255] and converted to a CPU numpy array.
    """
    denormalized = (image * imagenet_std + imagenet_mean) * 255
    clamped = torch.clip(denormalized, 0, 255)
    return clamped.int().cpu().numpy()


def reconstruct(image):
    """Run the MAE model on `image` and build items for the output gallery.

    Returns a list of (numpy image, caption) pairs: the original input, the
    masked input, the raw reconstruction, and the reconstruction with the
    visible (unmasked) patches pasted back in.
    """
    pixel_values = image_processor.preprocess(image, return_tensors='pt').pixel_values

    outputs = model(pixel_values)

    # Reconstructed image, converted to channels-last for display.
    recon = model.unpatchify(outputs.logits)
    recon = torch.einsum('nchw->nhwc', recon).detach().cpu()

    # Expand the per-patch mask to pixel resolution: 1 = removed, 0 = kept.
    patch_elems = model.config.patch_size ** 2 * 3  # pixels per patch (p*p*3)
    mask = outputs.mask.detach().unsqueeze(-1).repeat(1, 1, patch_elems)
    mask = model.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # Original input in channels-last layout.
    original = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()

    # Input with the masked patches blanked out.
    masked = original * (1 - mask)

    # Reconstruction with the visible patches copied from the input.
    pasted = original * (1 - mask) + recon * mask

    return [(prep_image(original[0]), 'original'),
            (prep_image(masked[0]), 'masked'),
            (prep_image(recon[0]), 'reconstruction'),
            (prep_image(pasted[0]), 'reconstruction + visible')]


# Build the Gradio UI: an image input, a trigger button, and a 4-image
# gallery showing the MAE visualization produced by reconstruct().
with gr.Blocks() as demo:
    with gr.Column(variant='panel'):
        with gr.Column():
            # PIL input so image_processor.preprocess() receives a PIL image.
            img = gr.Image(
                container=False,
                type='pil'
            )
            btn = gr.Button(
                'Apply F1 MAE',
                 scale=0
            )

            # One row of four tiles: original / masked / reconstruction /
            # reconstruction + visible patches.
            gallery = gr.Gallery(
                columns=4,
                rows=1,
                height='300px',
                object_fit='none'
            )

    # Wire the button: run the model on the uploaded image, show the results.
    btn.click(reconstruct, img, gallery)

if __name__ == "__main__":
    demo.launch()