Spaces:
Runtime error
Runtime error
deploying to spaces
Browse files- app.py +127 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from torchvision import models, transforms
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
import pickle
|
9 |
+
from PIL import Image
|
10 |
+
import datetime
|
11 |
+
import gdown
|
12 |
+
import urllib.request
|
13 |
+
import gradio as gr
|
14 |
+
import markdown
|
15 |
+
|
16 |
+
url = 'https://drive.google.com/file/d/1qKiyp4r8SqUtz2ZWk3E6oZhyhl6t8lyG/view?usp=sharing'
|
17 |
+
path_class_names = "./class_names_restnet_leeds_butterfly.pkl"
|
18 |
+
gdown.download(url, path_class_names, quiet=False)
|
19 |
+
|
20 |
+
url = 'https://drive.google.com/file/d/1Ep2YWU4M-yVkF7AFP3aD1sVhuriIDzFe/view?usp=sharing'
|
21 |
+
path_model = "./model_state_restnet_leeds_butterfly.pth"
|
22 |
+
gdown.download(url, path_model, quiet=False)
|
23 |
+
|
24 |
+
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Red_postman_butterfly_%28Heliconius_erato%29.jpg/1599px-Red_postman_butterfly_%28Heliconius_erato%29.jpg"
|
25 |
+
path_input = "./h_erato.jpg"
|
26 |
+
urllib.request.urlretrieve(url, filename=path_input)
|
27 |
+
|
28 |
+
url = "https://www.ukbutterflies.co.uk/photo_album/source/664a285ca7b4379147d598ea5127228f.jpg"
|
29 |
+
path_input = "./d_plexippus.jpg"
|
30 |
+
urllib.request.urlretrieve(url, filename=path_input)
|
31 |
+
|
32 |
+
# normalisation
|
33 |
+
data_transforms_test = transforms.Compose([
|
34 |
+
transforms.Resize(256),
|
35 |
+
transforms.CenterCrop(224),
|
36 |
+
transforms.ToTensor(),
|
37 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
38 |
+
])
|
39 |
+
|
40 |
+
class_names = pickle.load(open(path_class_names, "rb"))
|
41 |
+
|
42 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
43 |
+
|
44 |
+
model_ft = models.resnet18(pretrained=True)
|
45 |
+
num_ftrs = model_ft.fc.in_features
|
46 |
+
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
|
47 |
+
model_ft = model_ft.to(device)
|
48 |
+
model_ft.load_state_dict(copy.deepcopy(torch.load(path_model,device)))
|
49 |
+
|
50 |
+
# Proper labeling
|
51 |
+
id_to_name = {
|
52 |
+
'001_Danaus Plexippus': 'Danaus plexippus - Monarch',
|
53 |
+
'002_Heliconius Charitonius': 'Heliconius charitonius - Zebra Longwing',
|
54 |
+
'003_Heliconius Erato': 'Heliconius erato - Red Postman',
|
55 |
+
'004_Junonia Coenia': 'Junonia coenia - Common Buckeye',
|
56 |
+
'005_Lycaena Phlaeas': 'Lycaena phlaeas - Small Copper',
|
57 |
+
'006_Nymphalis Antiopa': 'Nymphalis antiopa - Mourning Cloak',
|
58 |
+
'007_Papilio Cresphontes': 'Papilio cresphontes - Giant Swallowtail',
|
59 |
+
'008_Pieris Rapae': 'Pieris rapae - Cabbage White',
|
60 |
+
'009_Vanessa Atalanta': 'Vanessa atalanta - Red Admiral',
|
61 |
+
'010_Vanessa Cardui': 'Vanessa cardui - Painted Lady',
|
62 |
+
}
|
63 |
+
|
64 |
+
def do_inference(img):
|
65 |
+
img_t = data_transforms_test(img)
|
66 |
+
batch_t = torch.unsqueeze(img_t, 0)
|
67 |
+
model_ft.eval()
|
68 |
+
# We don't need gradients for test, so wrap in
|
69 |
+
# no_grad to save memory
|
70 |
+
with torch.no_grad():
|
71 |
+
batch_t = batch_t.to(device)
|
72 |
+
# forward propagation
|
73 |
+
output = model_ft( batch_t)
|
74 |
+
# get prediction
|
75 |
+
probs = torch.nn.functional.softmax(output, dim=1)
|
76 |
+
output = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
|
77 |
+
probs = probs.cpu().numpy()[0]
|
78 |
+
probs = probs[output]
|
79 |
+
labels = np.array(class_names)[output]
|
80 |
+
return {id_to_name[labels[i]]: round(float(probs[i]),2) for i in range(len(labels))}
|
81 |
+
|
82 |
+
im = gr.inputs.Image(shape=(512, 512), image_mode='RGB',
|
83 |
+
invert_colors=False, source="upload",
|
84 |
+
type="pil")
|
85 |
+
title = "Butterfly Classification Demo"
|
86 |
+
description = "A pretrained ResNet18 CNN trained on the Leeds Butterfly Dataset. Libraries: PyTorch, Gradio."
|
87 |
+
examples = [['./h_erato.jpg'],['d_plexippus.jpg']]
|
88 |
+
article_text = markdown.markdown('''
|
89 |
+
|
90 |
+
<h1 style="color:white">PyTorch image classification - A pretrained ResNet18 CNN trained on the <a href="http://www.josiahwang.com/dataset/leedsbutterfly/">Leeds Butterfly Dataset</a></h1>
|
91 |
+
<br>
|
92 |
+
<p>The Leeds Butterfly Dataset consists of 832 images in 10 classes:</p>
|
93 |
+
<ul>
|
94 |
+
<li>Danaus plexippus - Monarch</li>
|
95 |
+
<li>Heliconius charitonius - Zebra Longwing</li>
|
96 |
+
<li>Heliconius erato - Red Postman</li>
|
97 |
+
<li>Lycaena phlaeas - Small Copper</li>
|
98 |
+
<li>Junonia coenia - Common Buckeye</li>
|
99 |
+
<li>Nymphalis antiopa - Mourning Cloak</li>
|
100 |
+
<li>Papilio cresphontes - Giant Swallowtail</li>
|
101 |
+
<li>Pieris rapae - Cabbage White</li>
|
102 |
+
<li>Vanessa atalanta - Red Admiral</li>
|
103 |
+
<li>Vanessa cardui - Painted Lady</li>
|
104 |
+
</ul>
|
105 |
+
<br>
|
106 |
+
<p>Part of a dissertation project. Author: <a href="https://github.com/ttheland">ttheland</a></p>
|
107 |
+
''')
|
108 |
+
|
109 |
+
# enable queue
|
110 |
+
enable_queue = True
|
111 |
+
|
112 |
+
iface = gr.Interface(
|
113 |
+
do_inference,
|
114 |
+
im,
|
115 |
+
gr.outputs.Label(num_top_classes=2),
|
116 |
+
live=False,
|
117 |
+
interpretation=None,
|
118 |
+
title=title,
|
119 |
+
description=description,
|
120 |
+
article= article_text,
|
121 |
+
examples=examples,
|
122 |
+
enable_queue=enable_queue
|
123 |
+
)
|
124 |
+
|
125 |
+
iface.test.launch()
|
126 |
+
|
127 |
+
iface.launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torchvision
|
2 |
+
gdown
|
3 |
+
markdown
|