gaviego commited on
Commit
a8cceae
1 Parent(s): 6afcd6e

clear working

Browse files
Files changed (2) hide show
  1. app.py +48 -13
  2. mnist.pth +0 -0
app.py CHANGED
@@ -8,6 +8,9 @@ from models import Net,NetConv
8
  net = torch.load('mnist.pth')
9
  net.eval()
10
 
 
 
 
11
  def predict(img):
12
  arr = np.array(img) / 255 # Assuming img is in the range [0, 255]
13
  arr = np.expand_dims(arr, axis=0) # Add batch dimension
@@ -16,21 +19,53 @@ def predict(img):
16
  topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
17
  return [str(k) for k in topk_indices[0].tolist()]
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  with gr.Blocks() as iface:
20
  gr.Markdown("# MNIST + Gradio End to End")
21
  gr.HTML("Shows end to end MNIST training with Gradio interface")
22
- with gr.Row():
23
- with gr.Column():
24
- sp = gr.Sketchpad(shape=(28, 28))
25
- with gr.Row():
26
- with gr.Column():
27
- pred_button = gr.Button("Predict")
28
- with gr.Column():
29
- clear = gr.Button("Clear")
30
- with gr.Column():
31
- label1 = gr.Label(label='1st Pred')
32
- label2 = gr.Label(label='2nd Pred')
33
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  pred_button.click(predict, inputs=sp, outputs=[label1,label2])
35
- clear.click(lambda: None, None, sp, queue=False)
 
 
 
 
36
  iface.launch()
 
8
  net = torch.load('mnist.pth')
9
  net.eval()
10
 
11
+ net_conv = torch.load('mnist_conv.pth')
12
+ net_conv.eval()
13
+
14
  def predict(img):
15
  arr = np.array(img) / 255 # Assuming img is in the range [0, 255]
16
  arr = np.expand_dims(arr, axis=0) # Add batch dimension
 
19
  topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
20
  return [str(k) for k in topk_indices[0].tolist()]
21
 
22
+ def predict_conv(img):
23
+ arr = np.array(img) / 255 # Assuming img is in the range [0, 255]
24
+ arr = np.expand_dims(arr, axis=0) # Conv needs one more dimension
25
+ arr = np.expand_dims(arr, axis=0) # Add batch dimension
26
+ arr = torch.from_numpy(arr).float() # Convert to PyTorch tensor
27
+ output = net_conv(arr)
28
+ topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
29
+ return [str(k) for k in topk_indices[0].tolist()]
30
+
31
+
32
+
33
+
34
+
35
  with gr.Blocks() as iface:
36
  gr.Markdown("# MNIST + Gradio End to End")
37
  gr.HTML("Shows end to end MNIST training with Gradio interface")
38
+ with gr.Tab("Linear Model"):
39
+ with gr.Row():
40
+ with gr.Column():
41
+ sp = gr.Sketchpad(shape=(28, 28))
42
+ with gr.Row():
43
+ with gr.Column():
44
+ pred_button = gr.Button("Predict")
45
+ with gr.Column():
46
+ clear_button = gr.Button("Clear")
47
+ with gr.Column():
48
+ label1 = gr.Label(label='1st Pred')
49
+ label2 = gr.Label(label='2nd Pred')
50
+
51
+ with gr.Tab("Convolution Model"):
52
+ with gr.Row():
53
+ with gr.Column():
54
+ sp_conv = gr.Sketchpad(shape=(28, 28))
55
+ with gr.Row():
56
+ with gr.Column():
57
+ pred_conv_button = gr.Button("Predict")
58
+ with gr.Column():
59
+ clear_button_conv = gr.Button("Clear")
60
+ with gr.Column():
61
+ label1_conv = gr.Label(label='1st Pred')
62
+ label2_conv = gr.Label(label='2nd Pred')
63
+ def clear():
64
+ return ['','',None,'','',None]
65
  pred_button.click(predict, inputs=sp, outputs=[label1,label2])
66
+ pred_conv_button.click(predict_conv, inputs=sp_conv, outputs=[label1_conv,label2_conv])
67
+ clear_button.click( lambda: ['','',None], None, [label1,label2,sp,], queue=False)
68
+ clear_button_conv.click( lambda: ['','',None], None, [label1_conv,label2_conv, sp_conv], queue=False)
69
+
70
+
71
  iface.launch()
mnist.pth CHANGED
Binary files a/mnist.pth and b/mnist.pth differ