balakrish181 commited on
Commit
58ee800
1 Parent(s): 41ea819

first commit

Browse files
mnist_app/__pycache__/model.cpython-39.pyc ADDED
Binary file (1.68 kB). View file
 
mnist_app/app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import load_model
2
+
3
+ import gradio as gr
4
+ import os
5
+ import torch
6
+
7
+
8
+
9
+ model,transforming,classes = load_model()
10
+
11
+ def predict(img):
12
+ img = transforming(img)
13
+ img = img.unsqueeze(0)
14
+
15
+
16
+ model.eval()
17
+ with torch.inference_mode():
18
+ pred_probs = torch.softmax(model(img), dim=1)
19
+ return {str(i): float(pred_probs[0][i]) for i in range(len(pred_probs[0]))}
20
+
21
+ title = 'MNIST Digit Prediction'
22
+ description = 'Predict handwritten digits (0-9) using a trained model.'
23
+ inputs = gr.Image(type='pil', label='Upload an image of a digit')
24
+ outputs = gr.Label(num_top_classes=3, label='Predictions')
25
+ demo = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title=title, description=description)
26
+
27
+ # Launch the interface
28
+ demo.launch()
mnist_app/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6f4516312e2020aa20cde939bfa5663d3001cfe79ebd70ef3ab75f495f05298
3
+ size 24612
mnist_app/examples/0_mnist.png ADDED
mnist_app/examples/3.png ADDED
mnist_app/model.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+ from torchvision import transforms
5
+
6
+
7
+ class MnistModel(nn.Module):
8
+
9
+ classes = ['0 - zero',
10
+ '1 - one',
11
+ '2 - two',
12
+ '3 - three',
13
+ '4 - four',
14
+ '5 - five',
15
+ '6 - six',
16
+ '7 - seven',
17
+ '8 - eight',
18
+ '9 - nine']
19
+ def __init__(self, *args, **kwargs) -> None:
20
+ super().__init__(*args, **kwargs)
21
+ self.conv1 = nn.Conv2d(1, 3, 3)
22
+ self.conv2 = nn.Conv2d(3, 6, 3)
23
+ self.maxpool = nn.MaxPool2d(2, 2)
24
+ self.fc1 = nn.Linear(150, 32)
25
+ self.fc2 = nn.Linear(32, 10)
26
+ #self.fc3 = nn.Linear(32, 10)
27
+ self.dropout = nn.Dropout(0.3)
28
+
29
+ def forward(self, x):
30
+ l1 = nn.ReLU()(self.conv1(x))
31
+ l1 = self.maxpool(l1)
32
+ l2 = nn.ReLU()(self.conv2(l1))
33
+ l2 = self.maxpool(l2)
34
+ fc = torch.flatten(l2, 1)
35
+ fc1 = nn.ReLU()(self.fc1(fc))
36
+ fc1 = self.dropout(fc1)
37
+ #fc2 = nn.ReLU()(self.fc2(fc1))
38
+ out = self.fc2(fc1)
39
+ return out
40
+
41
+
42
+ def load_model():
43
+ model = MnistModel()
44
+ transforming = transforms.Compose([
45
+ transforms.Resize((28,28)),
46
+ transforms.ToTensor(),
47
+ transforms.Grayscale(num_output_channels=1)
48
+ ])
49
+
50
+ model.load_state_dict(torch.load('demos/mnist_app/best_model.pth',map_location='cpu'))
51
+
52
+ return model,transforming,model.classes
53
+
54
+ if __name__=='__main__':
55
+ pass
mnist_app/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==4.31.5
2
+ pathlib==1.0.1
3
+ torch==2.2.2
4
+ torchvision==0.17.2
5
+