akhaliq HF staff commited on
Commit
1c804e6
1 Parent(s): 1b78b3e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ import gradio as gr
5
+ import os
6
+
7
+ import torch
8
+ # load WRN-50-2:
9
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=True)
10
+ # or WRN-101-2
11
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet101_2', pretrained=True)
12
+ model.eval()
13
+
14
+ os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
15
+
16
+ torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
17
+
18
+ def inference(input_image):
19
+
20
+ preprocess = transforms.Compose([
21
+ transforms.Resize(256),
22
+ transforms.CenterCrop(224),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
25
+ ])
26
+ input_tensor = preprocess(input_image)
27
+ input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
28
+
29
+ # move the input and model to GPU for speed if available
30
+ if torch.cuda.is_available():
31
+ input_batch = input_batch.to('cuda')
32
+ model.to('cuda')
33
+
34
+ with torch.no_grad():
35
+ output = model(input_batch)
36
+ # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
37
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
38
+
39
+ # Read the categories
40
+ with open("imagenet_classes.txt", "r") as f:
41
+ categories = [s.strip() for s in f.readlines()]
42
+ # Show top categories per image
43
+ top5_prob, top5_catid = torch.topk(probabilities, 5)
44
+ result = {}
45
+ for i in range(top5_prob.size(0)):
46
+ result[categories[top5_catid[i]]] = top5_prob[i].item()
47
+ return result
48
+
49
+ inputs = gr.inputs.Image(type='pil')
50
+ outputs = gr.outputs.Label(type="confidences",num_top_classes=5)
51
+
52
+ title = "Wide_Resnet"
53
+ description = "Gradio demo for Wide Resnet, Wide Residual Networks. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
54
+
55
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1605.07146'>Wide Residual Networks</a> | <a href='https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py'>Github Repo</a></p>"
56
+
57
+ examples = [
58
+ ['dog.jpg']
59
+ ]
60
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()