Samuel Diaz commited on
Commit
0395eb2
1 Parent(s): b0cad4d

First Approach

Browse files
Files changed (1) hide show
  1. app.py +213 -1
app.py CHANGED
@@ -1,8 +1,220 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def image_mod(image):
4
- return image.rotate(45)
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  iface = gr.Interface(image_mod, gr.Image(type="pil"), "image")
7
 
8
  iface.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ from sklearn.preprocessing import LabelEncoder
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision
11
+ from torchvision import transforms
12
+ import torchvision.models as models
13
+ from torchvision.datasets import ImageFolder
14
+ from torch.utils.data.dataset import Dataset
15
+ from torch.utils.data import Dataset, random_split, DataLoader
16
+ from torch.utils.data import DataLoader
17
+ from sklearn.model_selection import train_test_split
18
+ import torchmetrics
19
+ from tqdm.notebook import tqdm
20
 
21
+ class ImageClassificationBase(torch.nn.Module):
22
+ # training step
23
+ def training_step(self, batch):
24
+ img, targets = batch
25
+ out = self(img)
26
+ loss = F.nll_loss(out, targets)
27
+ return loss
28
+
29
+ # validation step
30
+ def validation_step(self, batch):
31
+ img, targets = batch
32
+ out = self(img)
33
+ loss = F.nll_loss(out, targets)
34
+ acc = accuracy(out, targets)
35
+ return {'val_acc':acc.detach(), 'val_loss':loss.detach()}
36
+
37
+ # validation epoch end
38
+ def validation_epoch_end(self, outputs):
39
+ batch_losses = [x['val_loss'] for x in outputs]
40
+ epoch_loss = torch.stack(batch_losses).mean()
41
+ batch_accs = [x['val_acc'] for x in outputs]
42
+ epoch_acc = torch.stack(batch_accs).mean()
43
+ return {'val_loss':epoch_loss.item(), 'val_acc':epoch_acc.item()}
44
+
45
+ # print result end epoch
46
+ def epoch_end(self, epoch, result):
47
+ print("Epoch [{}] : train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result["train_loss"], result["val_loss"], result["val_acc"]))
48
+
49
+ class DogBreedPretrainedWideResnet(ImageClassificationBase):
50
+ def __init__(self):
51
+ super().__init__()
52
+
53
+ self.network = models.wide_resnet50_2(pretrained=True)
54
+ # Replace last layer
55
+ num_ftrs = self.network.fc.in_features
56
+ self.network.fc = torch.nn.Sequential(
57
+ torch.nn.Linear(num_ftrs, 120),
58
+ torch.nn.LogSoftmax(dim=1)
59
+ )
60
+
61
+ def forward(self, xb):
62
+ return self.network(xb)
63
+
64
+ def predict_single(img):
65
+ xb = img.unsqueeze(0) # adding extra dimension
66
+ xb = to_device(xb, device)
67
+ preds = model(xb) # change model object here
68
+ predictions = preds[0]
69
+
70
+ max_val, kls = torch.max(predictions, dim=0)
71
+ print('Predicted :', breeds[kls])
72
+ plt.imshow(img.permute(1,2,0))
73
+ plt.show()
74
+
75
+ def get_default_device():
76
+ if torch.cuda.is_available():
77
+ return torch.device('cuda')
78
+ else:
79
+ return torch.device('cpu')
80
+
81
+ def to_device(data, device):
82
+ if isinstance(data, (list, tuple)):
83
+ return [to_device(d, device) for d in data]
84
+ else:
85
+ return data.to(device, non_blocking=True)
86
+
87
+ def accuracy(outputs, labels):
88
+ _, preds = torch.max(outputs, dim=1)
89
+ return torch.tensor(torch.sum(preds == labels).item() / len(preds))
90
+
91
  def image_mod(image):
92
+ return predict_single(image)
93
 
94
+ device = get_default_device()
95
+ PATH = "./model/model.zip"
96
+ model = DogBreedPretrainedWideResnet()
97
+ model.load_state_dict(torch.load(PATH))
98
+ breeds=['Chihuahua',
99
+ 'Japanese spaniel',
100
+ 'Maltese dog',
101
+ 'Pekinese',
102
+ 'Shih Tzu',
103
+ 'Blenheim spaniel',
104
+ 'papillon',
105
+ 'toy terrier',
106
+ 'Rhodesian ridgeback',
107
+ 'Afghan hound',
108
+ 'basset',
109
+ 'beagle',
110
+ 'bloodhound',
111
+ 'bluetick',
112
+ 'black and tan coonhound',
113
+ 'Walker hound',
114
+ 'English foxhound',
115
+ 'redbone',
116
+ 'borzoi',
117
+ 'Irish wolfhound',
118
+ 'Italian greyhound',
119
+ 'whippet',
120
+ 'Ibizan hound',
121
+ 'Norwegian elkhound',
122
+ 'otterhound',
123
+ 'Saluki',
124
+ 'Scottish deerhound',
125
+ 'Weimaraner',
126
+ 'Staffordshire bullterrier',
127
+ 'American Staffordshire terrier',
128
+ 'Bedlington terrier',
129
+ 'Border terrier',
130
+ 'Kerry blue terrier',
131
+ 'Irish terrier',
132
+ 'Norfolk terrier',
133
+ 'Norwich terrier',
134
+ 'Yorkshire terrier',
135
+ 'wire haired fox terrier',
136
+ 'Lakeland terrier',
137
+ 'Sealyham terrier',
138
+ 'Airedale',
139
+ 'cairn',
140
+ 'Australian terrier',
141
+ 'Dandie Dinmont',
142
+ 'Boston bull',
143
+ 'miniature schnauzer',
144
+ 'giant schnauzer',
145
+ 'standard schnauzer',
146
+ 'Scotch terrier',
147
+ 'Tibetan terrier',
148
+ 'silky terrier',
149
+ 'soft coated wheaten terrier',
150
+ 'West Highland white terrier',
151
+ 'Lhasa',
152
+ 'flat coated retriever',
153
+ 'curly coated retriever',
154
+ 'golden retriever',
155
+ 'Labrador retriever',
156
+ 'Chesapeake Bay retriever',
157
+ 'German short haired pointer',
158
+ 'vizsla',
159
+ 'English setter',
160
+ 'Irish setter',
161
+ 'Gordon setter',
162
+ 'Brittany spaniel',
163
+ 'clumber',
164
+ 'English springer',
165
+ 'Welsh springer spaniel',
166
+ 'cocker spaniel',
167
+ 'Sussex spaniel',
168
+ 'Irish water spaniel',
169
+ 'kuvasz',
170
+ 'schipperke',
171
+ 'groenendael',
172
+ 'malinois',
173
+ 'briard',
174
+ 'kelpie',
175
+ 'komondor',
176
+ 'Old English sheepdog',
177
+ 'Shetland sheepdog',
178
+ 'collie',
179
+ 'Border collie',
180
+ 'Bouvier des Flandres',
181
+ 'Rottweiler',
182
+ 'German shepherd',
183
+ 'Doberman',
184
+ 'miniature pinscher',
185
+ 'Greater Swiss Mountain dog',
186
+ 'Bernese mountain dog',
187
+ 'Appenzeller',
188
+ 'EntleBucher',
189
+ 'boxer',
190
+ 'bull mastiff',
191
+ 'Tibetan mastiff',
192
+ 'French bulldog',
193
+ 'Great Dane',
194
+ 'Saint Bernard',
195
+ 'Eskimo dog',
196
+ 'malamute',
197
+ 'Siberian husky',
198
+ 'affenpinscher',
199
+ 'basenji',
200
+ 'pug',
201
+ 'Leonberg',
202
+ 'Newfoundland',
203
+ 'Great Pyrenees',
204
+ 'Samoyed',
205
+ 'Pomeranian',
206
+ 'chow',
207
+ 'keeshond',
208
+ 'Brabancon griffon',
209
+ 'Pembroke',
210
+ 'Cardigan',
211
+ 'toy poodle',
212
+ 'miniature poodle',
213
+ 'standard poodle',
214
+ 'Mexican hairless',
215
+ 'dingo',
216
+ 'dhole',
217
+ 'African hunting dog']
218
  iface = gr.Interface(image_mod, gr.Image(type="pil"), "image")
219
 
220
  iface.launch()