"""Gradio demo: ImageNet classification with an MAE fine-tuned ViT-Large/16."""
import sys
import os
import subprocess
import requests
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr

# timm==0.4.5 # 0.3.2 does not work in Colab


def _download(url, dst):
    """Fetch ``url`` into the local path ``dst`` unless ``dst`` already exists.

    ``torch.hub.download_url_to_file`` writes into a temporary file and only
    renames it to ``dst`` once the transfer completes, so an interrupted
    download never leaves a truncated file behind that this existence check
    would later mistake for a complete one.
    """
    if not os.path.exists(dst):
        torch.hub.download_url_to_file(url, dst)


# ImageNet label names, one per line; line i is assumed to name model output class i.
_download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
          "imagenet_classes.txt")

# The MAE repository provides the model definitions (models_vit); clone it once.
# check=True surfaces a failed clone here instead of as a confusing ImportError below.
if not os.path.isdir("mae"):
    subprocess.run(["git", "clone", "https://github.com/facebookresearch/mae.git"], check=True)
sys.path.append('./mae')
import models_mae
import models_vit

# Run on the GPU when one is available; decided once at startup.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Resize (shorter side 256), center-crop to 224x224, convert to a tensor and
# normalize with the ImageNet mean/std. Built once instead of on every request.
PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Label names read once at startup rather than per request.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    CATEGORIES = [s.strip() for s in f]


def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    """Build a ``models_vit`` classifier and load fine-tuned weights into it.

    Args:
        chkpt_dir: Path to the checkpoint *file* (despite the name). The file
            must contain the state dict under the ``'model'`` key.
        arch: Name of a constructor in ``models_vit``. The default matches the
            ViT-Large/16 checkpoint this demo downloads.

    Returns:
        The model with weights loaded (strictly), on CPU, in eval mode.
    """
    # build model
    model = getattr(models_vit, arch)(global_pool=True)
    # load model
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=True)
    print(msg)
    # Disable training-only behavior (e.g. dropout) before serving predictions.
    model.eval()
    return model


def inference(input_image):
    """Classify a PIL image and return its 5 most likely ImageNet labels.

    Args:
        input_image: A ``PIL.Image`` in any mode (RGB, RGBA, L, ...).

    Returns:
        A dict mapping label name to softmax probability for the top-5 classes,
        in the shape ``gr.outputs.Label(type="confidences")`` displays.
    """
    # Uploads (e.g. PNGs) may be RGBA or greyscale; the 3-channel Normalize needs RGB.
    input_tensor = PREPROCESS(input_image.convert('RGB'))
    # Add a batch dimension and place the input on the same device as the model.
    input_batch = input_tensor.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_batch)
    # The output has unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    top5_prob, top5_catid = torch.topk(probabilities, 5)
    return {
        CATEGORIES[cat_id]: prob
        for prob, cat_id in zip(top5_prob.tolist(), top5_catid.tolist())
    }


chkpt_dir = 'mae_finetuned_vit_large.pth'
_download("https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_large.pth", chkpt_dir)
model = prepare_model(chkpt_dir, 'vit_large_patch16')
# Move the weights once at startup instead of on every request.
model.to(DEVICE)

# Download the example images shown under the demo.
_download("https://estaticos.megainteresting.com/media/cache/1140x_thumb/uploads/images/gallery/5e7c585f5cafe8134048af67/gato-persa-gris_0.jpg", "persian_cat.jpg")
_download("https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg", "fox.jpg")
_download("https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg", "cucumber.jpg")

# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces (removed in
# Gradio 4.x); this script needs a Gradio version that still provides them.
inputs = gr.inputs.Image(type='pil')
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)
title = "MAE"
description = "Gradio demo for Masked Autoencoders (MAE) ImageNet classification (large-patch16). To use it, simply upload your image, or click on the examples to load them. Read more at the links below."
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a>"
    " | "
    "<a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a>"
    "</p>"
)
examples = [
    ['persian_cat.jpg'],
    ['fox.jpg'],
    ['cucumber.jpg'],
]

gr.Interface(inference, inputs, outputs, title=title, description=description,
             article=article, examples=examples, analytics_enabled=False).launch()