|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor |
|
from PIL import Image |
|
import uuid |
|
|
|
# Hugging Face checkpoint that provides both the preprocessing and the model weights.
_CHECKPOINT = "facebook/vit-mae-base"

feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)

# Per-channel normalization statistics of the feature extractor, as numpy arrays;
# show_image uses them to undo the normalization before plotting.
imagenet_mean, imagenet_std = (
    np.array(stats)
    for stats in (feature_extractor.image_mean, feature_extractor.image_std)
)
|
|
|
def show_image(image, title=''):
    """Denormalize an image tensor and save it as a uniquely named PNG file.

    Args:
        image: Channels-last tensor of shape (height, width, 3), normalized with
            ``imagenet_mean`` / ``imagenet_std``.
        title: Optional title drawn above the image. Nothing is drawn when empty.

    Returns:
        The path of the saved PNG, relative to the current working directory.

    Raises:
        ValueError: If ``image`` is not a 3-D array whose last axis has size 3.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"expected an (H, W, 3) image, got shape {tuple(image.shape)}"
        )

    # Undo the normalization and map to integer pixel values in [0, 255].
    pixels = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()

    path = f"{uuid.uuid4()}.png"

    # Draw on a dedicated figure and always close it. Relying on pyplot's implicit
    # "current figure" would reuse the same axes across calls (stacking images)
    # and never release the figure, leaking memory in a long-running server.
    fig, ax = plt.subplots()
    try:
        ax.imshow(pixels)
        ax.axis('off')
        if title:
            ax.set_title(title)
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    return path
|
|
|
def visualize(image):
    """Run ViTMAE on an image and render its masking and reconstruction stages.

    Args:
        image: Value delivered by the ``gr.Image`` input, or ``None`` when the
            input has been cleared. NOTE(review): presumably a numpy array
            (Gradio's default image type); confirm against the Gradio version in use.

    Returns:
        A list of ``(png_path, caption)`` pairs for ``gr.Gallery``, in this order:
        the original image, the image with masked patches blanked out, the raw
        reconstruction, and the visible patches pasted over the reconstruction.
        Returns an empty list when ``image`` is ``None``.
    """
    # The `change` event also fires when the input is cleared; nothing to render.
    if image is None:
        return []

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

    # Inference only: skip autograd bookkeeping (this also makes detach() unnecessary).
    with torch.no_grad():
        outputs = model(pixel_values)

    # Predicted pixel values for every patch, reassembled into an image and
    # converted to channels-last (N, H, W, C) layout for plotting.
    y = model.unpatchify(outputs.logits)
    y = torch.einsum('nchw->nhwc', y).cpu()

    # Expand the per-patch mask to one value per pixel/channel of its patch, so it
    # can be unpatchified into image shape exactly like the logits above.
    # 1 marks a hidden patch and 0 a visible one (see im_masked below).
    values_per_patch = model.config.patch_size ** 2 * model.config.num_channels
    mask = outputs.mask.unsqueeze(-1).repeat(1, 1, values_per_patch)
    mask = model.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask).cpu()

    x = torch.einsum('nchw->nhwc', pixel_values)

    # Input with the hidden patches zeroed out.
    im_masked = x * (1 - mask)

    # Visible patches taken from the input, hidden patches from the reconstruction.
    im_paste = x * (1 - mask) + y * mask

    gallery_labels = ["Original Image", "Masked Image", "Reconstruction", "Reconstruction with Patches"]
    gallery_out = [show_image(out) for out in [x[0], im_masked[0], y[0], im_paste[0]]]

    return list(zip(gallery_out, gallery_labels))
|
|
|
|
|
|
|
|
|
# Gradio UI: upload an image (or pick the bundled example) to see the
# original, masked, reconstructed, and pasted views side by side.
with gr.Blocks() as demo:
    gr.Markdown("## ViTMAE Demo")
    gr.Markdown("ViTMAE is an architecture that combines a masked autoencoder and a Vision Transformer (ViT) for self-supervised pre-training.")
    gr.Markdown("By pre-training a ViT to reconstruct pixel values for masked patches, one can get results after fine-tuning that outperform supervised pre-training.")

    with gr.Row():
        input_img = gr.Image()
        output = gr.Gallery()

    # Re-run the model every time the input image changes.
    input_img.change(visualize, inputs=input_img, outputs=output)

    gr.Examples([["./cat.png"]], inputs=input_img, outputs=output, fn=visualize)

demo.launch(debug=True)
|
|