Andrei Boiarov
Update app file
859c3ef
raw
history blame
No virus
2.07 kB
from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
import numpy as np
import torch
from PIL import Image
import gradio as gr
# Hub checkpoint id, named once so the preprocessor and the model are always
# loaded from the same repository.
_CHECKPOINT = 'andrewbo29/vit-mae-base-formula1'

feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)

# Per-channel normalisation statistics taken from the loaded feature extractor.
# The names say "imagenet", but the values are whatever the checkpoint's
# preprocessing config provides.
imagenet_mean, imagenet_std = (
    np.array(feature_extractor.image_mean),
    np.array(feature_extractor.image_std),
)
def prep_image(image, mean=None, std=None):
    """Undo channel normalisation and return a displayable 8-bit image.

    Computes ``(image * std + mean) * 255``, clips the result to ``[0, 255]``
    and truncates it to unsigned 8-bit integers.

    Args:
        image: Torch tensor whose last dimension broadcasts against ``mean``
            and ``std`` (the caller in this file passes channels-last data).
        mean: Per-channel mean used for normalisation. Defaults to the
            module-level ``imagenet_mean``.
        std: Per-channel standard deviation used for normalisation. Defaults
            to the module-level ``imagenet_std``.

    Returns:
        A ``numpy.uint8`` array with the same shape as ``image``.
    """
    if mean is None:
        mean = imagenet_mean
    if std is None:
        std = imagenet_std

    # Convert the statistics explicitly instead of relying on implicit
    # Tensor-with-ndarray arithmetic; keep them on the image's device.
    mean = torch.as_tensor(mean, device=image.device)
    std = torch.as_tensor(std, device=image.device)

    pixels = (image * std + mean) * 255
    # uint8 is the conventional dtype for image arrays. The values are already
    # clipped to 0..255, so the cast only truncates the fractional part.
    # NOTE(review): wider integer dtypes may be rescaled by downstream image
    # converters - confirm against the Gradio version in use.
    return torch.clip(pixels, 0, 255).to(torch.uint8).cpu().numpy()
def reconstruct(img):
    """Run the MAE model on an image and build the three gallery views.

    Args:
        img: Image data accepted by ``PIL.Image.fromarray`` (the array handed
            over by the ``gr.Image`` input), or ``None`` if no image was given.

    Returns:
        A list of three ``(image_array, caption)`` tuples, in this order:
        the masked input, the raw reconstruction, and the reconstruction with
        the visible patches pasted back in.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Report a readable message in the UI instead of failing inside
        # Image.fromarray.
        raise gr.Error('Please upload an image first.')

    image = Image.fromarray(img)
    pixel_values = feature_extractor(image, return_tensors='pt').pixel_values

    # Inference only: no_grad avoids building an autograd graph, so the
    # results need no detach() afterwards.
    with torch.no_grad():
        outputs = model(pixel_values)

        # Patch predictions -> full image, then channels-last for display.
        y = model.unpatchify(outputs.logits)
        y = torch.einsum('nchw->nhwc', y).cpu()

        # Expand the per-patch mask so every pixel value of a patch carries
        # its flag, then map it back to image layout like the prediction.
        mask = outputs.mask
        mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

        x = torch.einsum('nchw->nhwc', pixel_values).cpu()

    # Masked image: removed patches are zeroed out.
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches.
    im_paste = im_masked + y * mask

    out_masked = prep_image(im_masked[0])
    out_rec = prep_image(y[0])
    out_rec_vis = prep_image(im_paste[0])

    return [
        (out_masked, 'masked'),
        (out_rec, 'reconstruction'),
        (out_rec_vis, 'reconstruction + visible'),
    ]
# Gradio UI: an image input and a button side by side, a three-column gallery
# for the results, and a click handler that feeds the image to reconstruct().
# NOTE(review): the nesting below follows the usual Gradio layout (inputs in
# the Row, gallery in the Column, event wiring directly under Blocks) -
# confirm against the deployed app.
# NOTE(review): the label and button texts read like leftovers from a
# text-to-image template; they are kept verbatim here.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row():
            img = gr.Image(label="Enter your prompt", container=False)
            btn = gr.Button("Generate image", scale=0)

        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[3],
            rows=[1],
            object_fit="contain",
            height='auto',
            container=True,
        )

    btn.click(fn=reconstruct, inputs=img, outputs=gallery)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()