File size: 2,579 Bytes
98d4fbe
dba7cbf
5066aaa
 
98d4fbe
 
 
 
 
 
 
5066aaa
 
98d4fbe
5066aaa
fa0722e
5066aaa
98d4fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5066aaa
e0f0523
98d4fbe
 
 
 
 
 
fa0722e
 
98d4fbe
fa0722e
98d4fbe
88ed06b
94a816d
fa0722e
 
 
98d4fbe
fa0722e
 
dc28d34
0e9f6ff
dc28d34
56dff7d
fa0722e
7a28a77
fa0722e
0ca4068
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import json
import os
import sys
import urllib.request

import matplotlib.pyplot as plt
import PIL
from PIL import Image

import torch
import torchvision
import torchvision.transforms as T

from timm import create_model

import gradio as gr


model_name = "convnext_xlarge_in22k"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# create a ConvNeXt model : https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py
model = create_model(model_name, pretrained=True).to(device)
# This demo only runs inference, so switch the network to evaluation mode.
# nn.Module instances start in training mode, where layers such as dropout
# behave stochastically.
model.eval()

# Define transforms for test
from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

NORMALIZE_MEAN = IMAGENET_DEFAULT_MEAN
NORMALIZE_STD = IMAGENET_DEFAULT_STD
SIZE = 256

# Here we resize smaller edge to 256, no center cropping
transforms = [
    T.Resize(SIZE, interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]

transforms = T.Compose(transforms)

# Lookup table used by inference(): keys are class indices as strings,
# values are the label text shown to the user.
LABELS_URL = "https://dl.fbaipublicfiles.com/convnext/label_to_words.json"
LABELS_PATH = "label_to_words.json"

# Download only when the file is missing, so restarts reuse the cached copy
# instead of re-fetching it (plain `wget` never overwrites and would pile up
# label_to_words.json.1, .2, ...). urlretrieve raises on failure rather than
# failing silently. Writing to a temp file and renaming avoids leaving a
# truncated JSON file behind if the download is interrupted.
if not os.path.exists(LABELS_PATH):
    _labels_tmp_path = LABELS_PATH + ".tmp"
    urllib.request.urlretrieve(LABELS_URL, _labels_tmp_path)
    os.replace(_labels_tmp_path, LABELS_PATH)

with open(LABELS_PATH, encoding="utf-8") as labels_file:
    imagenet_labels = json.load(labels_file)

def inference(img):
    """Classify a PIL image and return its five most likely labels.

    Args:
        img: a PIL image (the Gradio input is configured with type='pil').

    Returns:
        dict mapping each of the top-5 labels (looked up in
        ``imagenet_labels`` by stringified class index) to its softmax
        probability as a Python float.
    """
    # Defensive: Normalize() uses 3-channel mean/std, so make sure the
    # image has exactly three channels (e.g. RGBA or greyscale inputs).
    img_tensor = transforms(img.convert('RGB')).unsqueeze(0).to(device)
    # Pure inference: disable autograd so no computation graph (and its
    # activation memory) is kept for this very large model on each request.
    with torch.no_grad():
        probs = torch.softmax(model(img_tensor), dim=1)
    top5 = torch.topk(probs, k=5)

    # Same semantics as building the dict in a loop: insertion order follows
    # descending probability, and a repeated label keeps the later value.
    return {
        imagenet_labels[str(int(idx))]: float(prob)
        for prob, idx in zip(top5.values[0], top5.indices[0])
    }

# UI wiring: a PIL image goes in, a top-5 confidence label comes out.
# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are the
# legacy Gradio API; confirm the pinned gradio version still supports them.
inputs = gr.inputs.Image(type='pil')
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)

# Text rendered around the demo.
title = "ConvNeXt"
description = (
    "Gradio demo for ConvNeXt for image classification. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below."
)

article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2201.03545' target='_blank'>A ConvNet for the 2020s</a>"
    " | "
    "<a href='https://github.com/facebookresearch/ConvNeXt' target='_blank'>Github Repo</a>"
    " | "
    "<a href='https://github.com/leondgarse/keras_cv_attention_models' target='_blank'>pretrained ConvNeXt model from keras_cv_attention_models</a>"
    " | "
    "<a href='https://github.com/stanislavfort/adversaries_to_convnext' target='_blank'>examples usage from adversaries_to_convnext</a>"
    "</p>"
)

# Sample image offered as a one-click example in the UI.
examples = ['Tortoise-on-ground-surrounded-by-plants.jpeg']

demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
    examples=examples,
)
demo.launch(enable_queue=True)