karanravindra commited on
Commit
beb2ed5
1 Parent(s): 54fd737

fix gradio demo

Browse files
Files changed (2) hide show
  1. app.py +6 -8
  2. requirements.txt +0 -1
app.py CHANGED
@@ -3,22 +3,19 @@ import time
3
  import gradio as gr
4
  import torch
5
  import torch.nn.functional as F
 
6
  from huggingface_hub import hf_hub_download
7
 
8
  from digitnet import Model
9
 
10
  torch.set_grad_enabled(False)
11
 
12
- hf_hub_download("karanravindra/digitnet", filename="mnist", local_dir=".")
13
 
14
- model = Model(1, 16, num_classes=36)
15
- model.load_state_dict(
16
- torch.load("models/emnist-model.pth", weights_only=True, map_location="cpu")
17
- )
18
  model.eval()
19
 
20
- all_classes = list("0123456789abcdefghijklmnopqrstuvwxyz")
21
- all_classes = map(str.upper, all_classes)
22
 
23
 
24
  def predict(inputs: dict) -> dict[str, float]:
@@ -32,7 +29,6 @@ def predict(inputs: dict) -> dict[str, float]:
32
  logits = model(img)
33
 
34
  probs = torch.softmax(logits, dim=1).squeeze()
35
-
36
  return {c: p for c, p in zip(all_classes, probs.tolist())}
37
 
38
 
@@ -50,8 +46,10 @@ demo = gr.Interface(
50
  live=True,
51
  inputs=input,
52
  outputs="label",
 
53
  title="DigitNet",
54
  description="A simple handwritten number and letter classifier.\n\nDraw a digit or letter in the box below and see the model's predictions.",
 
55
  )
56
 
57
  if __name__ == "__main__":
 
3
  import gradio as gr
4
  import torch
5
  import torch.nn.functional as F
6
+ from torchvision.utils import save_image
7
  from huggingface_hub import hf_hub_download
8
 
9
  from digitnet import Model
10
 
11
  torch.set_grad_enabled(False)
12
 
13
+ hf_hub_download("karanravindra/digitnet", filename="model.ckpt", local_dir=".")
14
 
15
+ model = Model.load_from_checkpoint("model.ckpt", map_location="cpu")
 
 
 
16
  model.eval()
17
 
18
+ all_classes = list(map(str.upper, "0123456789"))
 
19
 
20
 
21
  def predict(inputs: dict) -> dict[str, float]:
 
29
  logits = model(img)
30
 
31
  probs = torch.softmax(logits, dim=1).squeeze()
 
32
  return {c: p for c, p in zip(all_classes, probs.tolist())}
33
 
34
 
 
46
  live=True,
47
  inputs=input,
48
  outputs="label",
49
+ submit_btn=gr.Button("Predict"),
50
  title="DigitNet",
51
  description="A simple handwritten number and letter classifier.\n\nDraw a digit or letter in the box below and see the model's predictions.",
52
+ flagging_mode="never"
53
  )
54
 
55
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  digitnet @ git+https://github.com/karanravindra/digitnet@main
2
 
3
  # This file was autogenerated via `uv export`.
4
- -e .
5
  aiofiles==23.2.1
6
  aiohappyeyeballs==2.4.3
7
  aiohttp==3.10.10
 
1
  digitnet @ git+https://github.com/karanravindra/digitnet@main
2
 
3
  # This file was autogenerated via `uv export`.
 
4
  aiofiles==23.2.1
5
  aiohappyeyeballs==2.4.3
6
  aiohttp==3.10.10