aliabd HF staff commited on
Commit
23082b3
β€’
1 Parent(s): ca56557

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +9 -4
run.py CHANGED
@@ -1,5 +1,6 @@
1
  from pathlib import Path
2
 
 
3
  import torch
4
  import gradio as gr
5
  from torch import nn
@@ -30,10 +31,10 @@ state_dict = torch.load('pytorch_model.bin', map_location='cpu')
30
  model.load_state_dict(state_dict, strict=False)
31
  model.eval()
32
 
33
- def predict(input):
34
- im = input
35
  if im is None:
36
  return None
 
37
 
38
  x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
39
 
@@ -47,5 +48,9 @@ def predict(input):
47
  return {LABELS[i]: v.item() for i, v in zip(indices, values)}
48
 
49
 
50
- interface = gr.Interface(predict, inputs=gr.templates.Sketchpad(label="Draw Here"), outputs=gr.Label(label="Guess"), theme="default", css=".footer{display:none !important}", live=True)
51
- interface.launch(enable_queue=False)
 
 
 
 
 
1
  from pathlib import Path
2
 
3
+ import numpy as np
4
  import torch
5
  import gradio as gr
6
  from torch import nn
 
31
  model.load_state_dict(state_dict, strict=False)
32
  model.eval()
33
 
34
+ def predict(im):
 
35
  if im is None:
36
  return None
37
+ im = np.asarray(im.resize((28, 28)))
38
 
39
  x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
40
 
 
48
  return {LABELS[i]: v.item() for i, v in zip(indices, values)}
49
 
50
 
51
+ interface = gr.Interface(predict,
52
+ inputs=gr.Sketchpad(label="Draw Here", brush_radius=5, type="pil", shape=(120, 120)),
53
+ outputs=gr.Label(label="Guess"),
54
+ live=True)
55
+
56
+ interface.queue().launch()