# Andrei Boiarov
# v1 version of app
# 14bb247
# raw
# history blame contribute delete
# No virus
# 2.11 kB
from transformers import ViTMAEForPreTraining, ViTImageProcessor
import numpy as np
import torch
import gradio as gr
# Single checkpoint identifier shared by the image processor and the model,
# so the two can never be loaded from different sources by accident.
CHECKPOINT = 'andrewbo29/vit-mae-base-formula1'

image_processor = ViTImageProcessor.from_pretrained(CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(CHECKPOINT)

# Per-channel normalisation statistics, read from the image processor and
# kept as NumPy arrays. prep_image() uses them to undo the normalisation.
imagenet_mean, imagenet_std = (
    np.array(stats)
    for stats in (image_processor.image_mean, image_processor.image_std)
)
def prep_image(image, mean=None, std=None):
    """Undo per-channel normalisation and return a displayable 0-255 array.

    Computes ``(image * std + mean) * 255``, clamps the result to
    ``[0, 255]`` and truncates it to integers.

    Args:
        image: Normalised ``torch.Tensor``. ``mean`` and ``std`` are
            broadcast against it, so for 1-D statistics the channel axis
            must be the last one.
        mean: Per-channel mean to add back. Defaults to the module-level
            ``imagenet_mean``.
        std: Per-channel standard deviation to multiply by. Defaults to
            the module-level ``imagenet_std``.

    Returns:
        A NumPy ``int32`` array with the same shape as ``image`` and values
        in ``[0, 255]``.
    """
    if mean is None:
        mean = imagenet_mean
    if std is None:
        std = imagenet_std

    # Convert the statistics to tensors explicitly rather than relying on
    # implicit Tensor-with-ndarray arithmetic; keep them on the image's device.
    mean = torch.as_tensor(mean, device=image.device)
    std = torch.as_tensor(std, device=image.device)

    pixels = (image * std + mean) * 255
    return torch.clip(pixels, 0, 255).int().cpu().numpy()
def reconstruct(image):
    """Run the masked autoencoder on one image and build the gallery panels.

    The image is preprocessed with the module-level ``image_processor``,
    passed through ``model``, and the prediction and mask are unpatchified
    back to image layout. Four views are produced with ``prep_image``.

    Args:
        image: Input image as delivered by the Gradio image widget
            (configured with ``type='pil'`` in this file). ``None`` means
            the button was clicked before anything was uploaded.

    Returns:
        A list of ``(array, caption)`` tuples, in this order: the original
        image, the masked image, the raw reconstruction, and the
        reconstruction pasted together with the visible patches.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the image processor.
        raise gr.Error('Please upload an image first.')

    pixel_values = image_processor.preprocess(image, return_tensors='pt').pixel_values

    # Inference only: skip autograd bookkeeping for the forward pass and the
    # tensor reshaping that follows.
    with torch.no_grad():
        outputs = model(pixel_values)
        y = model.unpatchify(outputs.logits)

        # Expand the per-patch mask to per-pixel so it can be unpatchified
        # like the prediction: (N, H*W) -> (N, H*W, p*p*3).
        mask = outputs.mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping

    # Move everything to channels-last on the CPU for display.
    y = y.permute(0, 2, 3, 1).cpu()
    mask = mask.permute(0, 2, 3, 1).cpu()
    x = pixel_values.permute(0, 2, 3, 1).cpu()

    # Masked image: removed patches are zeroed (in normalised space).
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches.
    im_paste = im_masked + y * mask

    # Only the first (and, for a single upload, only) batch item is shown.
    return [
        (prep_image(x[0]), 'original'),
        (prep_image(im_masked[0]), 'masked'),
        (prep_image(y[0]), 'reconstruction'),
        (prep_image(im_paste[0]), 'reconstruction + visible'),
    ]
# Page layout: a panel holding the image upload and its trigger button,
# followed by a single-row, four-column gallery for the results.
# NOTE(review): the nesting below was reconstructed from the context-manager
# structure; confirm that the gallery belongs to the outer panel column and
# the button to the inner column.
with gr.Blocks() as demo:
    with gr.Column(variant='panel'):
        with gr.Column():
            img = gr.Image(container=False, type='pil')
            btn = gr.Button('Apply F1 MAE', scale=0)

        gallery = gr.Gallery(columns=4, rows=1, height='300px', object_fit='none')

    # Clicking the button passes the uploaded image to reconstruct() and
    # shows the returned (image, caption) pairs in the gallery.
    btn.click(fn=reconstruct, inputs=img, outputs=gallery)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()