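# Gradio demo: ImageNet classification with NVIDIA's pretrained ResNeXt101-32x4d,
# loaded via torch.hub and served through a simple Gradio interface.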
import torch
import gradio as gr
import torchvision.transforms as transforms

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained ResNeXt101-32x4d model and NVIDIA's processing helpers
# (used below to map class probabilities to human-readable labels).
resneXt = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resneXt')
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')

# Inference only: disable dropout/batch-norm updates and move to the device.
resneXt.eval().to(device)

def inference(img):
    # Match the evaluation preprocessing of the training recipe: resize the
    # short side to 256, center-crop to the model's 224x224 input, to tensor.
    img_transforms = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    )
    img = img_transforms(img)

    # Normalize with the ImageNet mean and std. They are not multiplied by 255
    # as they are in the training script: the torch dataloader there reads raw
    # bytes, whereas ToTensor() on a PIL image already yields floats in [0, 1].
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    batch = img.float().unsqueeze(0).sub_(mean).div_(std).to(device)

    with torch.no_grad():
        output = torch.nn.functional.softmax(resneXt(batch), dim=1)

    # Return the five most likely ImageNet labels with their confidences.
    results = utils.pick_n_best(predictions=output, n=5)
    return results
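
# Optional sanity check outside the UI (assumes a local 'food.jpeg', the same
# file referenced in the Gradio examples below). Uncomment to try it:
# from PIL import Image
# print(inference(Image.open("food.jpeg").convert("RGB")))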
  
title="ResNeXt101"
description="Gradio demo for ResNeXt101, ResNet with bottleneck 3x3 Convolutions substituted by 3x3 Grouped Convolutions, trained with mixed precision using Tensor Cores. To use it, simply upload your image or click on one of the examples below. Read more at the links below"

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1611.05431'>Aggregated Residual Transformations for Deep Neural Networks</a> | <a href='https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnext101-32x4d'>Github Repo</a></p>"

examples = [['food.jpeg']]

# Note: this targets the Gradio 2.x API; on Gradio 3+ the input would be
# gr.Image(type="pil") and queuing is enabled via .queue() before .launch().
gr.Interface(
    inference,
    gr.inputs.Image(type="pil"),
    "text",
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(enable_queue=True)