File size: 1,207 Bytes
5814591
 
 
 
 
 
 
4308ef9
5814591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95249fb
 
5814591
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Pretrained NTS-Net classifier for the 200 CUB bird classes, fetched through
# torch.hub and put in eval mode for inference on the CPU.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,  # NTS-Net hyper-parameter forwarded to the hub entrypoint — presumably the number of part proposals; confirm
    device='cpu',
    num_classes=200,
).eval()

# Test-time preprocessing: resize to 600x600, center-crop to 448x448, convert
# to a float tensor, then normalize per channel with the common ImageNet
# mean/std (presumably matching the model's training pipeline — confirm).
transform_test = transforms.Compose([
    # Use torchvision's InterpolationMode enum; passing PIL integer constants
    # such as Image.BILINEAR here is deprecated.
    transforms.Resize((600, 600), transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def predict(inp):
    """Classify a bird image and return a confidence for every class.

    Args:
        inp: image array from the Gradio ``Image`` input (anything with
            ``.astype``; presumably a NumPy array of shape H x W or
            H x W x C — confirm against the Gradio version in use).

    Returns:
        dict mapping ``model.bird_classes[i]`` to the softmax probability of
        class ``i`` as a Python float.
    """
    # Convert to true 3-channel RGB instead of forcing the raw buffer to be
    # read as RGB: grayscale (2-D) or RGBA arrays would otherwise be rejected
    # or misinterpreted before reaching the RGB-only preprocessing.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = transform_test(image).unsqueeze(0)  # add a leading batch dim of 1

    with torch.no_grad():
        # The model returns a 7-tuple; only the concatenated logits are used.
        _, _, _, concat_logits, _, _, _ = model(batch)

    # Explicit dim: softmax without `dim` is deprecated and warns. For the
    # (1, num_classes) logits expected here, -1 is the class axis, which is
    # what the legacy implicit choice picked for 2-D input.
    probs = torch.nn.functional.softmax(concat_logits, dim=-1).squeeze(0).tolist()
    return {model.bird_classes[i]: p for i, p in enumerate(probs)}


# Gradio UI: an image input wired to `predict`, whose returned confidences are
# shown as a label listing the ten highest-scoring classes. "gabbiano.jpg" is
# offered as a clickable example (path resolved relative to the working dir).
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=10)
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="200 Bird Species Classifications with NTS-NET (From CUB 200)",
    examples=["gabbiano.jpg"],
)
demo.launch()