# VITMAE / app.py
# Author: merve (HF staff) - commit b479d0e, "Create app.py" (2.38 kB).
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import uuid
# Pretrained MAE checkpoint shared by the preprocessor and the model.
_CHECKPOINT = "facebook/vit-mae-base"
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)
# Per-channel normalization statistics of the feature extractor, kept as numpy
# arrays so show_image can undo the normalization before display.
imagenet_mean, imagenet_std = (
    np.array(feature_extractor.image_mean),
    np.array(feature_extractor.image_std),
)
def show_image(image, title=''):
    """Denormalize an image tensor and save it as a borderless PNG.

    Args:
        image: Tensor of shape [H, W, 3] holding pixel values normalized with
            ``imagenet_mean`` / ``imagenet_std``.
        title: Optional title drawn above the image; nothing is drawn when it
            is empty (the default).

    Returns:
        Path (relative to the current working directory) of the written PNG.
        The file name is a random UUID, so separate calls never overwrite
        each other's output.

    Raises:
        ValueError: if ``image`` is not a 3-D array whose last dimension is 3.
    """
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an [H, W, 3] image, got shape {tuple(image.shape)}")
    # Undo the normalization and map to integer 0-255 pixel values for imshow.
    pixels = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()
    path = f"{uuid.uuid4()}.png"
    # Draw on a dedicated figure and always close it: plotting through pyplot's
    # implicit current figure would keep adding images to the same global axes
    # on every call (leaking memory), and concurrent calls would share them.
    fig, ax = plt.subplots()
    try:
        ax.imshow(pixels)
        ax.axis('off')
        if title:
            ax.set_title(title)
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return path
def visualize(image):
    """Run ViTMAE on an image and build gallery items for its masking and reconstruction.

    Args:
        image: Value of the ``gr.Image`` input (anything the feature extractor
            accepts), or ``None`` when no image is loaded, e.g. after the
            input was cleared.

    Returns:
        List of ``(png_path, caption)`` tuples for ``gr.Gallery``, in this
        order:
        - the preprocessed input;
        - the masked input;
        - the raw reconstruction;
        - the reconstruction with the visible patches pasted back.
        Returns an empty list when ``image`` is ``None``.
    """
    # The change event also fires with None when the image is removed; there
    # is nothing to visualize and the feature extractor would fail on it.
    if image is None:
        return []
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    # Inference only: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(pixel_values)
        # Per-patch predictions -> image layout, then NCHW -> NHWC for plotting.
        y = model.unpatchify(outputs.logits)
        y = torch.einsum('nchw->nhwc', y).cpu()
        # Expand the per-patch mask (1 is removing, 0 is keeping) to one value
        # per pixel and channel so it can be unpatchified like the predictions.
        mask = outputs.mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)
        mask = torch.einsum('nchw->nhwc', mask).cpu()
        x = torch.einsum('nchw->nhwc', pixel_values)
        # Masked image: only the patches the model was allowed to see.
        im_masked = x * (1 - mask)
        # MAE reconstruction with the visible patches pasted back in.
        im_paste = x * (1 - mask) + y * mask
    gallery_labels = ["Original Image", "Masked Image", "Reconstruction", "Reconstruction with Patches"]
    gallery_out = [show_image(out) for out in [x[0], im_masked[0], y[0], im_paste[0]]]
    return list(zip(gallery_out, gallery_labels))
# Gradio UI: an image input next to a gallery with the four visualizations.
with gr.Blocks() as demo:
    gr.Markdown("## ViTMAE Demo")
    gr.Markdown("ViTMAE is an architecture that combines a masked autoencoder (MAE) with a Vision Transformer (ViT) for self-supervised pre-training.")
    gr.Markdown("By pre-training a ViT to reconstruct the pixel values of masked patches, one can obtain results after fine-tuning that outperform supervised pre-training.")
    with gr.Row():
        input_img = gr.Image()
        output = gr.Gallery()
    # Recompute the visualizations whenever the input image changes.
    input_img.change(visualize, inputs=input_img, outputs=output)
    gr.Examples([["./cat.png"]], inputs=input_img, outputs=output, fn=visualize)
demo.launch(debug=True)