Spaces:
Runtime error
Runtime error
add app.py
Browse files- app.py +32 -0
- requirements.txt +42 -0
app.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
|
7 |
+
**{'topN': 6, 'device': 'cpu', 'num_classes': 200}).eval()
|
8 |
+
transform_test = transforms.Compose([
|
9 |
+
transforms.Resize((600, 600), Image.BILINEAR),
|
10 |
+
transforms.CenterCrop((448, 448)),
|
11 |
+
transforms.ToTensor(),
|
12 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
13 |
+
])
|
14 |
+
|
15 |
+
|
16 |
+
def predict(inp):
|
17 |
+
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
|
18 |
+
scaled_img = transform_test(inp)
|
19 |
+
torch_images = scaled_img.unsqueeze(0)
|
20 |
+
|
21 |
+
with torch.no_grad():
|
22 |
+
top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(
|
23 |
+
torch_images)
|
24 |
+
pred = torch.nn.functional.softmax(concat_logits)
|
25 |
+
return {model.bird_classes[i]: float(p) for i, p in enumerate(pred.squeeze(0))}
|
26 |
+
|
27 |
+
|
28 |
+
inputs = gr.inputs.Image()
|
29 |
+
outputs = gr.outputs.Label(num_top_classes=10)
|
30 |
+
gr.Interface(fn=predict, inputs=inputs, outputs=outputs,
|
31 |
+
title="200 Bird Species Classifications with NTS-NET (From CUB 200)").launch()
|
32 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
analytics-python==1.4.0
|
2 |
+
backoff==1.10.0
|
3 |
+
bcrypt==3.2.0
|
4 |
+
certifi==2021.10.8
|
5 |
+
cffi==1.15.0
|
6 |
+
charset-normalizer==2.0.7
|
7 |
+
click==8.0.3
|
8 |
+
cryptography==35.0.0
|
9 |
+
cycler==0.11.0
|
10 |
+
ffmpy==0.3.0
|
11 |
+
Flask==2.0.2
|
12 |
+
Flask-CacheBuster==1.0.0
|
13 |
+
Flask-Cors==3.0.10
|
14 |
+
Flask-Login==0.5.0
|
15 |
+
gradio==2.4.1
|
16 |
+
idna==3.3
|
17 |
+
itsdangerous==2.0.1
|
18 |
+
Jinja2==3.0.2
|
19 |
+
kiwisolver==1.3.2
|
20 |
+
markdown2==2.4.1
|
21 |
+
MarkupSafe==2.0.1
|
22 |
+
matplotlib==3.4.3
|
23 |
+
monotonic==1.6
|
24 |
+
numpy==1.21.3
|
25 |
+
pandas==1.3.4
|
26 |
+
paramiko==2.8.0
|
27 |
+
Pillow==8.4.0
|
28 |
+
pkg_resources==0.0.0
|
29 |
+
pycparser==2.20
|
30 |
+
pycryptodome==3.11.0
|
31 |
+
pydub==0.25.1
|
32 |
+
PyNaCl==1.4.0
|
33 |
+
pyparsing==3.0.3
|
34 |
+
python-dateutil==2.8.2
|
35 |
+
pytz==2021.3
|
36 |
+
requests==2.26.0
|
37 |
+
six==1.16.0
|
38 |
+
torch==1.10.0
|
39 |
+
torchvision==0.11.1
|
40 |
+
typing-extensions==3.10.0.2
|
41 |
+
urllib3==1.26.7
|
42 |
+
Werkzeug==2.0.2
|