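"""Gradio demo for Masked Autoencoders (MAE): ImageNet classification with the
ViT-Large (patch16) model fine-tuned by facebookresearch/mae."""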
import sys
import os

import torch
from torchvision import transforms
import gradio as gr


# Requires timm==0.4.5 (timm 0.3.2 does not work in Colab).

os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')

import models_vit  # provided by the cloned MAE repo


def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    # Build the ViT classifier; global_pool=True matches the MAE fine-tuning setup.
    model = getattr(models_vit, arch)(global_pool=True)
    # Load the fine-tuned weights; the official MAE checkpoints store the state_dict under 'model'.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=True)
    print(msg)
    model.eval()  # disable dropout/drop-path for deterministic inference
    return model


def inference(input_image):
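    # Standard ImageNet evaluation preprocessing: resize the short side to 256,
    # center-crop to 224x224, and normalize with the usual ImageNet mean/std.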
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
    # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Read the categories
    with open("imagenet_classes.txt", "r") as f:
        categories = [s.strip() for s in f.readlines()]
    # Show top categories per image
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    result = {}
    for i in range(top5_prob.size(0)):
        result[categories[top5_catid[i]]] = top5_prob[i].item()
    return result


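# Fetch the fine-tuned ViT-Large checkpoint (wget -nc skips the download if the
# file is already present) and build the classifier once at startup.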
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_large.pth")
chkpt_dir = 'mae_finetuned_vit_large.pth'
model = prepare_model(chkpt_dir, 'vit_large_patch16')


# Download example images for the demo
torch.hub.download_url_to_file("https://estaticos.megainteresting.com/media/cache/1140x_thumb/uploads/images/gallery/5e7c585f5cafe8134048af67/gato-persa-gris_0.jpg", "persian_cat.jpg")
torch.hub.download_url_to_file("https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg", "fox.jpg")
torch.hub.download_url_to_file("https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg", "cucumber.jpg")

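# Quick local sanity check (hypothetical; assumes the example images above were downloaded):
#   from PIL import Image
#   print(inference(Image.open("fox.jpg")))  # dict of top-5 {class name: probability}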
inputs = gr.inputs.Image(type='pil')
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)

title = "MAE"
description = "Gradio demo for Masked Autoencoders (MAE) ImageNet classification (large-patch16). To use it, simply upload your image, or click on the examples to load them. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a> | <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a></p>"

examples = [
    ['persian_cat.jpg'],
    ['fox.jpg'],
    ['cucumber.jpg'],
]

gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()