akhaliq HF staff commited on
Commit
8bc2e9b
1 Parent(s): f50fe13

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import torchvision.transforms as transforms
4
+
5
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
6
+
7
+ resnet50 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
8
+ utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')
9
+
10
+ resnet50.eval().to(device)
11
+
12
+ def inference(img):
13
+
14
+ img_transforms = transforms.Compose(
15
+ [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
16
+ )
17
+
18
+ img = img_transforms(img)
19
+ with torch.no_grad():
20
+ # mean and std are not multiplied by 255 as they are in training script
21
+ # torch dataloader reads data into bytes whereas loading directly
22
+ # through PIL creates a tensor with floats in [0,1] range
23
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
24
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
25
+ img = img.float()
26
+ img = img.unsqueeze(0).sub_(mean).div_(std)
27
+
28
+ batch = torch.cat(
29
+ [img]
30
+ ).to(device)
31
+
32
+ with torch.no_grad():
33
+ output = torch.nn.functional.softmax(resnet50(batch), dim=1)
34
+
35
+ results = utils.pick_n_best(predictions=output, n=5)
36
+
37
+ return results
38
+
39
+ title="ResNet50"
40
+ description="Gradio demo for ResNet50, ResNet50 model trained with mixed precision using Tensor Cores. To use it, simply upload your image or click on one of the examples below. Read more at the links below"
41
+
42
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1512.03385'>Deep Residual Learning for Image Recognition</a> | <a href='https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5'>Github Repo</a></p>"
43
+
44
+ examples=[['food.jpeg']]
45
+ gr.Interface(inference,gr.inputs.Image(type="pil"),"text",title=title,description=description,article=article,examples=examples).launch(enable_queue=True)