tshrusd commited on
Commit
a8fc39e
1 Parent(s): 9ac145a

first commmit

Browse files
Files changed (2) hide show
  1. app.py +66 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ import pytorch_lightning as pl
5
+
6
+ import torchvision.models as models
7
+ import gradio as gr
8
+
9
+
10
+ class Net(pl.LightningModule):
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ # init a pretrained resnet
16
+ backbone = models.resnet50()
17
+ num_filters = backbone.fc.in_features
18
+ layers = list(backbone.children())[:-1]
19
+ self.feature_extractor = nn.Sequential(*layers)
20
+
21
+ # use the pretrained model to classify cifar-10 (10 image classes)
22
+ num_target_classes = 2
23
+ self.classifier = nn.Linear(num_filters, num_target_classes)
24
+
25
+ def forward(self, x):
26
+ self.feature_extractor.eval()
27
+ with torch.no_grad():
28
+ representations = self.feature_extractor(x).flatten(1)
29
+ x = self.classifier(representations)
30
+ return x
31
+
32
+
33
+ net = Net()
34
+ net = net.load_from_checkpoint(
35
+ './resnet50_transfer.ckpt', map_location=torch.device('cpu'))
36
+ net.eval()
37
+
38
+
39
+ labels = ['Cat', 'Dog']
40
+ val_transform = transforms.Compose(
41
+ [
42
+ transforms.ToTensor(),
43
+ transforms.Resize(256),
44
+ transforms.CenterCrop(224),
45
+ ]
46
+ )
47
+
48
+
49
+ def inference(img):
50
+ img = val_transform(img)
51
+ img = torch.unsqueeze(img, 0)
52
+ with torch.no_grad():
53
+ pred = net(img)
54
+ pred = torch.softmax(pred, dim=1)
55
+ scores = pred.detach().numpy()[0]
56
+ confidences = {labels[i]: float(scores[i]) for i in range(len(labels))}
57
+ return confidences
58
+
59
+
60
+ iface = gr.Interface(
61
+ fn=inference,
62
+ inputs=gr.components.Image(type="numpy"),
63
+ outputs=gr.Label(num_top_classes=2)
64
+ )
65
+
66
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pytorch
2
+ pytorch-lightning
3
+ torchvision