ffcm commited on
Commit
908930d
1 Parent(s): 00e0338

updates desc

Browse files
Files changed (2) hide show
  1. app.py +7 -3
  2. networktorch.py +0 -50
app.py CHANGED
@@ -1,12 +1,10 @@
1
  import gradio as gr
2
- import pickle
3
  import torch
4
  import numpy as np
5
 
6
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
7
 
8
  with open('cnn_model.bin', 'rb') as f:
9
- # nn = pickle.load(f)
10
  nn = torch.load(f, map_location=torch.device('cpu'))
11
 
12
  nn.to(device)
@@ -24,10 +22,16 @@ def predict(input):
24
  return dict(enumerate(p.tolist()))
25
 
26
 
 
 
 
 
 
 
27
  demo = gr.Interface(
28
  fn=predict,
29
  title='ConvNet for handwritten digits classification',
30
- description='Created with PyTorch.',
31
  inputs=[
32
  gr.Sketchpad(
33
  shape=(28, 28),
 
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
 
5
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
 
7
  with open('cnn_model.bin', 'rb') as f:
 
8
  nn = torch.load(f, map_location=torch.device('cpu'))
9
 
10
  nn.to(device)
 
22
  return dict(enumerate(p.tolist()))
23
 
24
 
25
+ desc = """\
26
+ This project uses a Convolutional Neural Network to classify handwritten digits.
27
+ Trained on the MNIST dataset.
28
+ Use most of the drawing area for better results.
29
+ """
30
+
31
  demo = gr.Interface(
32
  fn=predict,
33
  title='ConvNet for handwritten digits classification',
34
+ description=desc,
35
  inputs=[
36
  gr.Sketchpad(
37
  shape=(28, 28),
networktorch.py DELETED
@@ -1,50 +0,0 @@
1
- from torch import nn
2
-
3
-
4
- class NeuralNetworkTorch(nn.Module):
5
- def __init__(self):
6
- super().__init__()
7
-
8
- self.stack = nn.Sequential(
9
- nn.Linear(784, 64),
10
- nn.Sigmoid(),
11
-
12
- nn.Linear(64, 10),
13
- nn.Sigmoid()
14
- )
15
-
16
- def forward(self, x):
17
- return self.stack(x)
18
-
19
-
20
- class ConvNeuralNetworkTorch(nn.Module):
21
- def __init__(self):
22
- super().__init__()
23
-
24
- self.conv = nn.Sequential(
25
- nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
26
- nn.ReLU(),
27
-
28
- nn.MaxPool2d(kernel_size=2, stride=2),
29
-
30
- nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
31
- nn.ReLU(),
32
-
33
- # nn.MaxPool2d(kernel_size=2, stride=2),
34
- )
35
-
36
- self.fc = nn.Sequential(
37
- nn.Linear(16 * 14 * 14, 10),
38
- nn.Sigmoid(),
39
- )
40
-
41
- def forward(self, x):
42
- # we do some reshaping here simply to avoid making changes to the caller
43
- # so it continues to work with the fully conected network above
44
- x = x.reshape(-1, 1, 28, 28) / 255
45
-
46
- conv_output = self.conv(x)
47
- flat = conv_output.reshape(len(x), -1)
48
- final_output = self.fc(flat)
49
-
50
- return final_output