"""Gradio demo for "Masked Autoencoders Are Scalable Vision Learners" (MAE).

On start-up the script:
  1. clones the official MAE repository (for its ``models_mae`` module) if it
     is not already present,
  2. downloads the ViT-Large visualization checkpoint if it is missing,
  3. builds the model and serves a Gradio interface.

Each request resizes the uploaded image to 224x224 and normalizes it with
ImageNet statistics. It then runs MAE with 75% of patches masked and returns
a PNG with four panels: original, masked, reconstruction, and reconstruction
pasted with the visible patches.
"""
import os
import subprocess
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image

MAE_REPO_URL = "https://github.com/facebookresearch/mae.git"
MAE_REPO_DIR = "mae"
CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth"
)
OUTPUT_PATH = "test.png"


def _ensure_repo(url, path):
    """Clone the git repository at *url* into *path* unless it already exists."""
    if not os.path.isdir(path):
        # check=True: without the repo, ``import models_mae`` below cannot work,
        # so fail here with a clear error instead of silently continuing.
        subprocess.run(["git", "clone", url, path], check=True)


def _download_if_missing(url, path, chunk_size=1 << 20):
    """Download *url* to *path* unless *path* already exists (like ``wget -nc``).

    The data is streamed into ``<path>.part`` and renamed only after the
    download completes, so an interrupted run never leaves a truncated
    checkpoint that a later start would try to load.
    """
    if os.path.exists(path):
        return
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in response.iter_content(chunk_size=chunk_size):
                fh.write(chunk)
    os.replace(tmp_path, path)


_ensure_repo(MAE_REPO_URL, MAE_REPO_DIR)
sys.path.append(os.path.join(".", MAE_REPO_DIR))
import models_mae  # noqa: E402  (only importable after the clone above)

# define the utils
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image, title=''):
    """Draw a normalized ``[H, W, 3]`` image tensor on the current axes.

    The ImageNet normalization is undone and values are clipped to 0..255
    before plotting.
    """
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build the ``models_mae`` architecture *arch* and load weights from *chkpt_dir*.

    The state dict is read from the checkpoint's ``'model'`` entry and loaded
    with ``strict=False``. The load result (missing/unexpected keys) is
    printed. The model is returned in eval mode.
    """
    # build model
    model = getattr(models_mae, arch)()
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    model.eval()  # inference only
    return model


def run_one_image(img, model, output_path=OUTPUT_PATH):
    """Run MAE (75% masking) on one normalized image and save a 4-panel figure.

    Args:
        img: normalized image array of shape ``[H, W, 3]``.
        model: a ``models_mae`` model, as returned by ``prepare_model``.
        output_path: where the PNG is written (defaults to ``"test.png"``).

    Returns:
        ``output_path``.
    """
    x = torch.tensor(img)
    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    with torch.no_grad():  # inference only: no autograd graph needed
        # run MAE
        loss, y, mask = model(x.float(), mask_ratio=0.75)
        y = model.unpatchify(y)
        y = torch.einsum('nchw->nhwc', y).cpu()

        # visualize the mask: expand each per-patch flag to p*p*3 values
        patch_size = model.patch_embed.patch_size[0]
        mask = mask.unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

    x = torch.einsum('nchw->nhwc', x)
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    # A fresh figure per call, always closed afterwards. Otherwise successive
    # requests draw onto the same global figure and leak memory.
    fig = plt.figure(figsize=(24, 6))
    try:
        for position, (panel, panel_title) in enumerate(panels, start=1):
            plt.subplot(1, 4, position)
            show_image(panel, panel_title)
        fig.savefig(output_path, bbox_inches='tight')
    finally:
        plt.close(fig)
    return output_path


# download checkpoint if not exist
chkpt_dir = 'mae_visualize_vit_large.pth'
_download_if_missing(CHECKPOINT_URL, chkpt_dir)
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')


def inference(img):
    """Gradio handler: reconstruct a PIL image with MAE and return the PNG path."""
    # Convert first so RGBA / grayscale / palette uploads also yield 3 channels.
    img = img.convert("RGB").resize((224, 224))
    img = np.array(img) / 255.
    if img.shape != (224, 224, 3):
        raise ValueError(f"expected a (224, 224, 3) image, got {img.shape}")

    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std

    torch.manual_seed(2)  # fixed seed -> the same random mask for every request
    return run_one_image(img, model_mae)


title = "MAE"
description = "Gradio Demo for Masked Autoencoders Are Scalable Vision Learners. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
# NOTE(review): the original article also showed a "visitor badge" image whose
# source URL was lost when the file was mangled; re-add it here if wanted.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.06377' target='_blank'>"
    "Masked Autoencoders Are Scalable Vision Learners</a> | "
    "<a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a>"
    "</p>"
)
examples = [['147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpeg']]
gr.Interface(
    inference,
    [gr.inputs.Image(type="pil")],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    allow_screenshot=False,
    examples=examples,
).launch(enable_queue=True)