Ahsen Khaliq commited on
Commit
aad70c7
1 Parent(s): 931df05

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ import torch
3
+ import urllib
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import torch
7
+
8
+ # Images
9
+ torch.hub.download_url_to_file('https://images.pexels.com/photos/17811/pexels-photo.jpg', 'bird.jpg')
10
+
11
+ model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
12
+ **{'topN': 6, 'device':'cpu', 'num_classes': 200})
13
+
14
+ transform_test = transforms.Compose([
15
+ transforms.Resize((600, 600), Image.BILINEAR),
16
+ transforms.CenterCrop((448, 448)),
17
+ # transforms.RandomHorizontalFlip(), # only if train
18
+ transforms.ToTensor(),
19
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
20
+ ])
21
+
22
+
23
+ model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True, **{'topN': 6, 'device':'cpu', 'num_classes': 200})
24
+
25
+ def birds(img):
26
+ scaled_img = transform_test(img)
27
+ torch_images = scaled_img.unsqueeze(0)
28
+
29
+ with torch.no_grad():
30
+ top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)
31
+
32
+ _, predict = torch.max(concat_logits, 1)
33
+ pred_id = predict.item()
34
+ return model.bird_classes[pred_id].split('.')[1]
35
+
36
+ inputs = gr.inputs.Image(type='pil', label="Original Image")
37
+ outputs = gr.outputs.Textbox(label="bird class")
38
+
39
+ title = "ntsnet"
40
+ description = "demo for ntsnet to classify birds. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
41
+ article = "<p style='text-align: center'><a href='http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf'>Are These Birds Similar: Learning Branched Networks for Fine-grained Representations</a> | <a href='https://github.com/nicolalandro/ntsnet-cub200'>Github Repo</a></p>"
42
+
43
+ examples = [
44
+ ['bird.jpg']
45
+ ]
46
+ gr.Interface(birds, inputs, outputs, title=title, description=description,
47
+ article=article, examples=examples, analytics_enabled=False).launch()