"""Gradio demo for MAE (Masked Autoencoders Are Scalable Vision Learners).

At start-up the script installs the pinned ``timm`` version, clones the
official MAE repository (which provides ``models_mae``), downloads the
large-ViT visualization checkpoint, and launches a Gradio interface. The
interface masks 75% of the patches of an uploaded image and plots the model's
reconstruction next to the original.
"""
import sys
import os
import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr

# Runtime setup. models_mae (from the MAE repo) is used with this pinned timm
# version, so it is installed before ``import models_mae`` below.
os.system("pip install timm==0.4.5")
# Skip the clone on restarts: git fails if the target directory already exists.
if not os.path.isdir("mae"):
    os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')
import models_mae  # noqa: E402  (only importable after the clone above)

# define the utils
# ImageNet per-channel RGB statistics, used to normalize inputs and to
# de-normalize images for display.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image, title=''):
    """Draw one ImageNet-normalized image on the current matplotlib axes.

    Args:
        image: tensor of shape [H, W, 3], normalized with ``imagenet_mean`` /
            ``imagenet_std``.
        title: title shown above the image.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    # Undo the normalization and clip to the displayable 0..255 integer range.
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')


def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build the MAE architecture ``arch`` and load weights from ``chkpt_dir``.

    Args:
        chkpt_dir: path to a checkpoint file whose ``'model'`` entry holds the
            state dict.
        arch: name of a model constructor in ``models_mae``.

    Returns:
        The model in eval mode, on CPU.
    """
    # build model
    model = getattr(models_mae, arch)()
    # load model; strict=False tolerates missing/unexpected keys, which are
    # reported in the printed message.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    model.eval()
    return model


def run_one_image(img, model):
    """Randomly mask ``img``, reconstruct it with ``model``, and plot the result.

    Args:
        img: numpy array of shape [H, W, 3], ImageNet-normalized
            (see ``inference``).
        model: an MAE model as returned by ``prepare_model``.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds four
        panels: original, masked, reconstruction, and reconstruction pasted
        with the visible patches. Gradio 2.x ``outputs.Image(type="plot")``
        expects exactly this object; it saves and then closes the figure.
    """
    x = torch.tensor(img)
    # make it a batch-like: [H, W, C] -> [1, H, W, C] -> [1, C, H, W]
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE; inference only, so skip building an autograd graph
    with torch.no_grad():
        loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask: broadcast the per-patch mask value to every pixel
    # channel of its patch so it can be unpatchified like an image
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # Draw on a fresh, large figure instead of mutating global rcParams or
    # reusing whatever figure happens to be current.
    plt.figure(figsize=(24, 24))
    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    for position, (panel, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(panel, panel_title)

    return plt


# download checkpoint if not exist (-nc: do not overwrite an existing file)
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth")

chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')


def inference(img):
    """Gradio callback: preprocess a PIL image and return the MAE visualization.

    Args:
        img: uploaded PIL image of any size or mode.

    Returns:
        Whatever ``run_one_image`` returns (the pyplot module with the figure).
    """
    # Force 3-channel RGB so grayscale / RGBA / palette uploads do not break
    # the shape check below.
    img = img.convert('RGB').resize((224, 224))
    img = np.array(img) / 255.

    assert img.shape == (224, 224, 3)

    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std

    # Fixed seed so the same image gets the same random mask on every request
    # (assumes the model's masking draws from torch's global RNG).
    torch.manual_seed(2)
    return run_one_image(img, model_mae)


title = "MAE"
description = "Gradio Demo for MAE. To use it, simply upload your image. Read more at the links below."
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.06377' target='_blank'>"
    "Masked Autoencoders Are Scalable Vision Learners</a>"
    " | <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a>"
    "</p>"
)
# No example images ship with this demo. The original referenced an undefined
# ``examples`` name; add e.g. [["example.jpg"]] here once files are provided.
examples = None

gr.Interface(
    inference,
    [gr.inputs.Image(type="pil")],
    gr.outputs.Image(type="plot"),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    examples=examples,
    allow_screenshot=False,
    enable_queue=True,
).launch()