gaviego commited on
Commit
6afcd6e
1 Parent(s): acdfdd8

convolution

Browse files
Files changed (5) hide show
  1. app.py +1 -1
  2. app_conv.py +40 -0
  3. mnist_conv.pth +0 -0
  4. model.py → models.py +25 -1
  5. train.py +2 -2
app.py CHANGED
@@ -3,7 +3,7 @@ from PIL import Image
3
  import numpy as np
4
  import torch
5
  import torch.nn as nn
6
- import model
7
 
8
  net = torch.load('mnist.pth')
9
  net.eval()
 
3
  import numpy as np
4
  import torch
5
  import torch.nn as nn
6
+ from models import Net,NetConv
7
 
8
  net = torch.load('mnist.pth')
9
  net.eval()
app_conv.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from models import NetConv
8
+
9
+
10
+ net_conv = torch.load('mnist_conv.pth')
11
+ net_conv.eval()
12
+
13
+ def predict(img):
14
+ arr = np.array(img) / 255 # Assuming img is in the range [0, 255]
15
+ arr.reshape(28,28)
16
+ arr = np.expand_dims(arr, axis=0) # Add batch dimension
17
+ arr = np.expand_dims(arr, axis=0) # Add batch dimension
18
+ arr = torch.from_numpy(arr).float() # Convert to PyTorch tensor
19
+ output = net_conv(arr)
20
+ topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
21
+ return [str(k) for k in topk_indices[0].tolist()]
22
+
23
+ with gr.Blocks() as iface:
24
+ gr.Markdown("# MNIST + Gradio End to End")
25
+ gr.HTML("Shows end to end MNIST training with Gradio interface")
26
+ with gr.Row():
27
+ with gr.Column():
28
+ sp = gr.Sketchpad(shape=(28, 28))
29
+ with gr.Row():
30
+ with gr.Column():
31
+ pred_button = gr.Button("Predict")
32
+ with gr.Column():
33
+ clear = gr.Button("Clear")
34
+ with gr.Column():
35
+ label1 = gr.Label(label='1st Pred')
36
+ label2 = gr.Label(label='2nd Pred')
37
+
38
+ pred_button.click(predict, inputs=sp, outputs=[label1,label2])
39
+ clear.click(lambda: None, None, sp, queue=False)
40
+ iface.launch()
mnist_conv.pth ADDED
Binary file (904 kB). View file
 
model.py → models.py RENAMED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
  # Define the model
4
  class Net(nn.Module):
5
  def __init__(self):
@@ -14,4 +15,27 @@ class Net(nn.Module):
14
  x = torch.relu(self.fc1(x))
15
  x = torch.relu(self.fc2(x))
16
  x = torch.relu(self.fc3(x))
17
- return self.fc4(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
  # Define the model
5
  class Net(nn.Module):
6
  def __init__(self):
 
15
  x = torch.relu(self.fc1(x))
16
  x = torch.relu(self.fc2(x))
17
  x = torch.relu(self.fc3(x))
18
+ return self.fc4(x)
19
+
20
+ class NetConv(nn.Module):
21
+ def __init__(self):
22
+ super(NetConv, self).__init__()
23
+ self.conv1 = nn.Conv2d(1, 32, 3)
24
+ self.conv2 = nn.Conv2d(32, 64, 3)
25
+ self.fc1 = nn.Linear(64 * 5 * 5, 128) # Corrected
26
+ self.fc2 = nn.Linear(128, 10)
27
+
28
+ def forward(self, x):
29
+ x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
30
+ x = F.max_pool2d(F.relu(self.conv2(x)), 2)
31
+ x = x.view(-1, self.num_flat_features(x))
32
+ x = F.relu(self.fc1(x))
33
+ x = self.fc2(x)
34
+ return F.log_softmax(x, dim=1)
35
+
36
+ def num_flat_features(self, x):
37
+ size = x.size()[1:]
38
+ num_features = 1
39
+ for s in size:
40
+ num_features *= s
41
+ return num_features
train.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
3
  import torch.optim as optim
4
  import torchvision
5
  import torchvision.transforms as transforms
6
- import model
7
  # Load the MNIST dataset
8
  train_set = torchvision.datasets.MNIST(root='./data', train=True,
9
  download=True, transform=transforms.ToTensor())
@@ -17,7 +17,7 @@ test_loader = torch.utils.data.DataLoader(test_set, batch_size=32,
17
 
18
 
19
 
20
- net = model.Net()
21
 
22
  # Use CrossEntropyLoss for multi-class classification
23
  criterion = nn.CrossEntropyLoss()
 
3
  import torch.optim as optim
4
  import torchvision
5
  import torchvision.transforms as transforms
6
+ from models import Net
7
  # Load the MNIST dataset
8
  train_set = torchvision.datasets.MNIST(root='./data', train=True,
9
  download=True, transform=transforms.ToTensor())
 
17
 
18
 
19
 
20
+ net = Net()
21
 
22
  # Use CrossEntropyLoss for multi-class classification
23
  criterion = nn.CrossEntropyLoss()