import sys
import os
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
# Requires timm==0.4.5 (0.3.2 does not work in Colab).
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')
import models_vit  # ViT classifier definitions from the cloned MAE repo
def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    # Build the classifier; global_pool=True matches the MAE fine-tuning setup.
    model = getattr(models_vit, arch)(global_pool=True)
    # Load the fine-tuned checkpoint.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=True)
    print(msg)
    model.eval()
    return model
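# A minimal usage sketch (assumption: the base-size checkpoint follows the same
# release naming pattern on dl.fbaipublicfiles.com; this demo itself only
# downloads the large checkpoint below):
#   model_b = prepare_model('mae_finetuned_vit_base.pth', 'vit_base_patch16')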
def inference(input_image):
    # Standard ImageNet preprocessing: resize, center-crop to 224, normalize.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    # Move the input and model to GPU for speed if available.
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')
    with torch.no_grad():
        output = model(input_batch)
    # The output has unnormalized scores; run a softmax to get probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Read the ImageNet category names.
    with open("imagenet_classes.txt", "r") as f:
        categories = [s.strip() for s in f.readlines()]
    # Return the top-5 categories with their probabilities.
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    result = {}
    for i in range(top5_prob.size(0)):
        result[categories[top5_catid[i]]] = top5_prob[i].item()
    return result
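# Quick local sanity check (a sketch, assuming the example images below have
# already been downloaded; not executed by the demo itself):
#   img = Image.open("fox.jpg").convert("RGB")
#   print(inference(img))  # dict mapping the top-5 class names to probabilities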
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_large.pth")
chkpt_dir = 'mae_finetuned_vit_large.pth'
model = prepare_model(chkpt_dir, 'vit_large_patch16')
# Download example images for the demo (Persian cat, fox, cucumber).
torch.hub.download_url_to_file("https://estaticos.megainteresting.com/media/cache/1140x_thumb/uploads/images/gallery/5e7c585f5cafe8134048af67/gato-persa-gris_0.jpg", "persian_cat.jpg")
torch.hub.download_url_to_file("https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg", "fox.jpg")
torch.hub.download_url_to_file("https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg", "cucumber.jpg")
# Gradio components (the gr.inputs / gr.outputs namespaces were removed in Gradio 3.x).
inputs = gr.Image(type='pil')
outputs = gr.Label(num_top_classes=5)
title = "MAE"
description = "Gradio demo for Masked Autoencoders (MAE) ImageNet classification (large-patch16). To use it, simply upload your image, or click on the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a> | <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a></p>"
examples = [
['persian_cat.jpg'],
['fox.jpg'],
['cucumber.jpg']
]
gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()
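# To run this demo locally (assumption: dependencies installed via
# `pip install torch torchvision timm==0.4.5 gradio`), execute:
#   python app.py
# and open the local URL that Gradio prints.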