tirendazakademi committed on
Commit
813e68c
1 Parent(s): 791337a

added files

Files changed (5)
  1. app.py +61 -0
  2. cat.jpg +0 -0
  3. dog.jpg +0 -0
  4. image_classifier.pth +3 -0
  5. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ from torch import nn
+ from torchvision import transforms
+ import gradio as gr
+
+ title = "PyTorch Cat vs Dog"
+ description = "Classifying cats and dogs with PyTorch"
+ article = "<p style='text-align: center'><a href='https://github.com/TirendazAcademy'>GitHub Repo</a></p>"
+
+ # The model architecture
+ class ImageClassifier(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv_layer_1 = nn.Sequential(
+             nn.Conv2d(3, 64, 3, padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(64),
+             nn.MaxPool2d(2))
+         self.conv_layer_2 = nn.Sequential(
+             nn.Conv2d(64, 512, 3, padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(512),
+             nn.MaxPool2d(2))
+         self.conv_layer_3 = nn.Sequential(
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(512),
+             nn.MaxPool2d(2))
+         self.classifier = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(in_features=512*3*3, out_features=2)
+         )
+
+     def forward(self, x: torch.Tensor):
+         x = self.conv_layer_1(x)
+         x = self.conv_layer_2(x)
+         # conv_layer_3 is deliberately reused four times: a 224x224 input is pooled
+         # down to a 512x3x3 feature map, matching the classifier's 512*3*3 input
+         # (see the shape check after this diff).
+         x = self.conv_layer_3(x)
+         x = self.conv_layer_3(x)
+         x = self.conv_layer_3(x)
+         x = self.conv_layer_3(x)
+         x = self.classifier(x)
+         return x
+
+ model = ImageClassifier()
+ model.load_state_dict(torch.load('image_classifier.pth', map_location='cpu'))
+ model.eval()  # inference mode, so the BatchNorm layers use their running statistics
+
+ def predict(inp):
+     image_transform = transforms.Compose([transforms.Resize(size=(224, 224)), transforms.ToTensor()])
+     labels = ['cat', 'dog']
+     inp = image_transform(inp).unsqueeze(dim=0)
+     with torch.no_grad():
+         prediction = torch.nn.functional.softmax(model(inp), dim=1)
+     confidences = {labels[i]: float(prediction.squeeze()[i]) for i in range(len(labels))}
+     return confidences
+
+ gr.Interface(fn=predict,
+              inputs=gr.Image(type="pil"),
+              outputs=gr.Label(num_top_classes=2),
+              title=title,
+              description=description,
+              article=article,
+              examples=['cat.jpg', 'dog.jpg']).launch()
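The repeated conv_layer_3 calls and the in_features=512*3*3 value go together: the forward pass runs six MaxPool2d(2) stages in total, each halving the spatial size. A minimal, self-contained shape check (not part of the commit; the bare pooling layer below is only a stand-in for the full model):

    import torch
    from torch import nn

    # app.py's forward applies six MaxPool2d(2) passes in total:
    # conv_layer_1, conv_layer_2, then conv_layer_3 four times.
    # Each pass halves the spatial size: 224 -> 112 -> 56 -> 28 -> 14 -> 7 -> 3.
    pool = nn.MaxPool2d(2)
    x = torch.zeros(1, 1, 224, 224)   # a single channel is enough to check spatial sizes
    for _ in range(6):
        x = pool(x)
    print(x.shape[-2:])  # torch.Size([3, 3]); with 512 channels this flattens to 512 * 3 * 3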
cat.jpg ADDED
dog.jpg ADDED
image_classifier.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:472e52a33ecd629811329db47e8e95c7ccffbff0a36f9678571eea3faa172782
+ size 10689971
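The weight file is tracked through Git LFS, so only the pointer above lands in the Git history. A state_dict checkpoint like this is what torch.save(model.state_dict(), ...) produces and what app.py's load_state_dict call expects. The sketch below only illustrates that round trip with a tiny stand-in module; the stand-in model and the file name demo_state_dict.pth are assumptions, not part of this repo.

    import torch
    from torch import nn

    # Stand-in for the trained ImageClassifier; in the real workflow the trained
    # model from app.py takes this place and the target file is image_classifier.pth.
    model = nn.Sequential(nn.Linear(4, 2))
    torch.save(model.state_dict(), "demo_state_dict.pth")

    # Reloading mirrors app.py: rebuild the same architecture, then load the weights.
    reloaded = nn.Sequential(nn.Linear(4, 2))
    reloaded.load_state_dict(torch.load("demo_state_dict.pth", map_location="cpu"))
    reloaded.eval()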
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==1.10.0
+ torchvision==0.11.0
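The pins presumably match the torch/torchvision pair the checkpoint was trained with; gradio itself is not pinned here, which is common for Gradio Spaces where the platform supplies it. An optional runtime check (not part of the repo) that the environment matches the pins:

    import torch
    import torchvision

    # Confirm the running environment matches the versions pinned in requirements.txt.
    print(torch.__version__, torchvision.__version__)
    assert torch.__version__.startswith("1.10")
    assert torchvision.__version__.startswith("0.11")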