z-uo commited on
Commit
5814591
1 Parent(s): 964f8e3

add app.py

Browse files
Files changed (2) hide show
  1. app.py +32 -0
  2. requirements.txt +42 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+
6
+ model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
7
+ **{'topN': 6, 'device': 'cpu', 'num_classes': 200}).eval()
8
+ transform_test = transforms.Compose([
9
+ transforms.Resize((600, 600), Image.BILINEAR),
10
+ transforms.CenterCrop((448, 448)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
13
+ ])
14
+
15
+
16
+ def predict(inp):
17
+ inp = Image.fromarray(inp.astype('uint8'), 'RGB')
18
+ scaled_img = transform_test(inp)
19
+ torch_images = scaled_img.unsqueeze(0)
20
+
21
+ with torch.no_grad():
22
+ top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(
23
+ torch_images)
24
+ pred = torch.nn.functional.softmax(concat_logits)
25
+ return {model.bird_classes[i]: float(p) for i, p in enumerate(pred.squeeze(0))}
26
+
27
+
28
+ inputs = gr.inputs.Image()
29
+ outputs = gr.outputs.Label(num_top_classes=10)
30
+ gr.Interface(fn=predict, inputs=inputs, outputs=outputs,
31
+ title="200 Bird Species Classifications with NTS-NET (From CUB 200)").launch()
32
+
requirements.txt ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ analytics-python==1.4.0
2
+ backoff==1.10.0
3
+ bcrypt==3.2.0
4
+ certifi==2021.10.8
5
+ cffi==1.15.0
6
+ charset-normalizer==2.0.7
7
+ click==8.0.3
8
+ cryptography==35.0.0
9
+ cycler==0.11.0
10
+ ffmpy==0.3.0
11
+ Flask==2.0.2
12
+ Flask-CacheBuster==1.0.0
13
+ Flask-Cors==3.0.10
14
+ Flask-Login==0.5.0
15
+ gradio==2.4.1
16
+ idna==3.3
17
+ itsdangerous==2.0.1
18
+ Jinja2==3.0.2
19
+ kiwisolver==1.3.2
20
+ markdown2==2.4.1
21
+ MarkupSafe==2.0.1
22
+ matplotlib==3.4.3
23
+ monotonic==1.6
24
+ numpy==1.21.3
25
+ pandas==1.3.4
26
+ paramiko==2.8.0
27
+ Pillow==8.4.0
28
+ pkg_resources==0.0.0
29
+ pycparser==2.20
30
+ pycryptodome==3.11.0
31
+ pydub==0.25.1
32
+ PyNaCl==1.4.0
33
+ pyparsing==3.0.3
34
+ python-dateutil==2.8.2
35
+ pytz==2021.3
36
+ requests==2.26.0
37
+ six==1.16.0
38
+ torch==1.10.0
39
+ torchvision==0.11.1
40
+ typing-extensions==3.10.0.2
41
+ urllib3==1.26.7
42
+ Werkzeug==2.0.2