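# Gradio demo for MAE ("Masked Autoencoders Are Scalable Vision Learners"):
# the model masks most of the input image's patches, and the demo visualizes
# the masked input, the MAE reconstruction, and the reconstruction pasted
# together with the visible patches.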
import sys
import os
import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr
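# clone the official MAE repository so its models_mae module can be imported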
os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')
import models_mae
# define the utils
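# per-channel ImageNet mean/std, used to normalize inputs for the model
# and to un-normalize images for display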
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')
    return
def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    # build model
    model = getattr(models_mae, arch)()
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
def run_one_image(img, model):
    x = torch.tensor(img)
    # add a batch dimension and move channels first: [H, W, 3] -> [1, 3, H, W]
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    # run MAE: mask 75% of the patches and reconstruct the full image
    loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()
    # visualize the mask
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    x = torch.einsum('nchw->nhwc', x)
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with the visible patches
    im_paste = x * (1 - mask) + y * mask
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]
    plt.subplot(1, 4, 1)
    show_image(x[0], "original")
    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")
    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")
    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")
    plt.savefig("test.png", bbox_inches='tight')
# download checkpoint if not exist
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth")
chkpt_dir = 'mae_visualize_vit_large.pth'
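# build a ViT-Large MAE and load the released visualization weights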
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')
def inference(img):
    # ensure a 3-channel RGB input and resize to the model's 224x224 resolution
    img = img.convert('RGB').resize((224, 224))
    img = np.array(img) / 255.
    assert img.shape == (224, 224, 3)
    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std
    # fix the seed so the random mask is reproducible
    torch.manual_seed(2)
    run_one_image(img, model_mae)
    return "test.png"
title = "MAE"
description = "Gradio demo for Masked Autoencoders Are Scalable Vision Learners (MAE). To use it, simply upload an image or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a> | <a href='https://github.com/facebookresearch/mae' target='_blank'>GitHub Repo</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_mae' alt='visitor badge'></center>"
examples=[['147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpeg']]
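# build and launch the Gradio interface (older gr.inputs/gr.outputs API):
# a PIL image goes in, the saved matplotlib figure comes back as a file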
gr.Interface(
    inference,
    [gr.inputs.Image(type="pil")],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    allow_screenshot=False,
    examples=examples,
).launch(enable_queue=True)