import sys
import os
import requests
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
# requires timm==0.4.5 (0.3.2 does not work in Colab)
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')
import models_vit
def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    # build model (global average pooling, as used for MAE fine-tuning)
    model = getattr(models_vit, arch)(global_pool=True)
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=True)
    print(msg)
    return model
def inference(input_image):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')
    with torch.no_grad():
        output = model(input_batch)
    # The output has unnormalized scores. To get probabilities, run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Read the categories
    with open("imagenet_classes.txt", "r") as f:
        categories = [s.strip() for s in f.readlines()]
    # Show top categories per image
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    result = {}
    for i in range(top5_prob.size(0)):
        result[categories[top5_catid[i]]] = top5_prob[i].item()
    return result
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_large.pth")
chkpt_dir = 'mae_finetuned_vit_large.pth'
model = prepare_model(chkpt_dir, 'vit_large_patch16')
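# A quick sanity check (a sketch, not part of the original app): the fine-tuned
# ViT-L/16 classifier should map a 1x3x224x224 batch to 1x1000 ImageNet logits.
with torch.no_grad():
    assert model(torch.zeros(1, 3, 224, 224)).shape == (1, 1000), "unexpected output shape"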
# Download example images for the demo (a Persian cat, a fox, and a cucumber)
torch.hub.download_url_to_file("https://estaticos.megainteresting.com/media/cache/1140x_thumb/uploads/images/gallery/5e7c585f5cafe8134048af67/gato-persa-gris_0.jpg", "persian_cat.jpg")
torch.hub.download_url_to_file("https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg", "fox.jpg")
torch.hub.download_url_to_file("https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg", "cucumber.jpg")
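# Optional smoke test (a sketch, not part of the original app): run one of the
# downloaded examples through inference() before launching the UI. Guarded by a
# hypothetical MAE_SMOKE_TEST environment variable so it stays off by default.
if os.environ.get("MAE_SMOKE_TEST"):
    print(inference(Image.open("fox.jpg").convert("RGB")))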
inputs = gr.Image(type='pil')
outputs = gr.Label(num_top_classes=5)
title = "MAE"
description = "Gradio demo for Masked Autoencoders (MAE) ImageNet classification (large-patch16). To use it, simply upload your image, or click on the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a> | <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a></p>"
examples = [
['persian_cat.jpg'],
['fox.jpg'],
['cucumber.jpg']
]
gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()