File size: 3,505 Bytes
3c6d2ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96d07a1
3c6d2ed
 
 
 
 
 
 
 
 
 
 
 
ce94770
44d61ed
ce94770
3c6d2ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7267a03
44d61ed
 
3c6d2ed
 
 
712e595
3c6d2ed
4f19843
3c6d2ed
936caa4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr


# Fetch the reference MAE implementation at startup and make it importable.
# NOTE(review): the exit status of the clone command is not checked here.
os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')

# Imported only after './mae' has been appended to sys.path above;
# presumably the module lives in the cloned repository.
import models_mae

# define the utils

# Per-channel (3-element) mean and standard deviation used for ImageNet-style
# normalization: inference() applies them, show_image() undoes them.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def show_image(image, title=''):
    """Draw a normalized image on the current matplotlib axes.

    The normalization is undone (multiply by ``imagenet_std``, add
    ``imagenet_mean``), the result is scaled by 255, clipped to [0, 255]
    and converted to integers before being passed to ``plt.imshow``.

    Args:
        image: torch tensor laid out as [H, W, 3].
        title: Caption drawn above the axes.
    """
    channels = image.shape[2]
    assert channels == 3

    # back from normalized space to 0-255 pixel values
    denormalized = image * imagenet_std + imagenet_mean
    pixels = torch.clip(denormalized * 255, 0, 255).int()

    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Instantiate an MAE architecture and load checkpoint weights into it.

    Args:
        chkpt_dir: Path of the checkpoint file handed to ``torch.load``.
        arch: Name of the factory attribute looked up on ``models_mae``.

    Returns:
        The constructed model after the checkpoint's ``'model'`` entry has
        been loaded into it non-strictly; the load report is printed.
    """
    # look up the factory by name and build the network
    build_fn = getattr(models_mae, arch)
    net = build_fn()

    # read the checkpoint onto the CPU and copy its weights over
    # NOTE(review): torch.load unpickles the file; use trusted checkpoints only
    state = torch.load(chkpt_dir, map_location='cpu')['model']
    load_report = net.load_state_dict(state, strict=False)
    print(load_report)

    return net

def run_one_image(img, model):
    """Run MAE on a single image and save a four-panel figure to ``test.png``.

    The panels are, left to right: the original image, the masked image,
    the model's reconstruction, and the reconstruction pasted together with
    the visible (unmasked) patches.

    Args:
        img: Normalized image data laid out as [H, W, 3]; it is converted
            with ``torch.tensor`` and given a leading batch dimension.
        model: MAE model. It is called as ``model(x, mask_ratio=0.75)`` and
            must provide ``unpatchify`` and ``patch_embed.patch_size``.

    Returns:
        None. The figure is written to ``test.png`` in the working directory.
    """
    x = torch.tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE; this is inference only, so do not record an autograd graph
    with torch.no_grad():
        _, y, mask = model(x.float(), mask_ratio=0.75)
        y = model.unpatchify(y)
        y = torch.einsum('nchw->nhwc', y).detach().cpu()

        # visualize the mask
        mask = mask.detach()
        values_per_patch = model.patch_embed.patch_size[0] ** 2 * 3
        mask = mask.unsqueeze(-1).repeat(1, 1, values_per_patch)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    # Draw on a fresh figure and always close it afterwards; otherwise the
    # implicit pyplot figure is reused by every call and keeps accumulating
    # image artists for the lifetime of the process.
    fig = plt.figure()
    try:
        panels = (
            (x[0], "original"),
            (im_masked[0], "masked"),
            (y[0], "reconstruction"),
            (im_paste[0], "reconstruction + visible"),
        )
        for position, (panel, caption) in enumerate(panels, start=1):
            plt.subplot(1, 4, position)
            show_image(panel, caption)

        plt.savefig("test.png", bbox_inches='tight')
    finally:
        plt.close(fig)
    

# download checkpoint if not exist
# (wget's -nc flag skips the download when the file is already present)
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth")

# Build the model once at import time; inference() reuses this instance.
chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')

    
def inference(img):
    """Run the MAE demo on one PIL image and return the result image path.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1],
    normalized with ``imagenet_mean`` / ``imagenet_std`` and passed to
    ``run_one_image`` together with the module-level ``model_mae``.

    Args:
        img: PIL image supplied by the Gradio input component.

    Returns:
        The string ``"test.png"``, the file written by ``run_one_image``.
    """
    # Force three channels so grayscale, RGBA or palette images do not trip
    # the shape check below.
    img = img.convert("RGB").resize((224, 224))
    img = np.array(img) / 255.

    # internal sanity check; guaranteed by the convert + resize above
    assert img.shape == (224, 224, 3)

    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std

    # fixed seed; presumably keeps the model's random masking reproducible
    torch.manual_seed(2)

    run_one_image(img, model_mae)
    return "test.png"
  

# Text passed to the Gradio interface below (title, description, article HTML).
title = "MAE"
description = "Gradio Demo for Masked Autoencoders Are Scalable Vision Learners. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a>| <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_mae' alt='visitor badge'></center>"

# Example input offered in the UI; the path is relative, so the file is
# expected to sit next to this script.
examples=[['147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpeg']]
# Wire inference() into a Gradio interface (PIL image in, image file out)
# and start serving with request queueing enabled.
gr.Interface(inference, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging="never",allow_screenshot=False,examples=examples).launch(enable_queue=True)