Spaces:
Running
Running
karanravindra
commited on
Commit
•
beb2ed5
1
Parent(s):
54fd737
fix gradio demo
Browse files- app.py +6 -8
- requirements.txt +0 -1
app.py
CHANGED
@@ -3,22 +3,19 @@ import time
|
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
from digitnet import Model
|
9 |
|
10 |
torch.set_grad_enabled(False)
|
11 |
|
12 |
-
hf_hub_download("karanravindra/digitnet", filename="
|
13 |
|
14 |
-
model = Model(
|
15 |
-
model.load_state_dict(
|
16 |
-
torch.load("models/emnist-model.pth", weights_only=True, map_location="cpu")
|
17 |
-
)
|
18 |
model.eval()
|
19 |
|
20 |
-
all_classes = list("
|
21 |
-
all_classes = map(str.upper, all_classes)
|
22 |
|
23 |
|
24 |
def predict(inputs: dict) -> dict[str, float]:
|
@@ -32,7 +29,6 @@ def predict(inputs: dict) -> dict[str, float]:
|
|
32 |
logits = model(img)
|
33 |
|
34 |
probs = torch.softmax(logits, dim=1).squeeze()
|
35 |
-
|
36 |
return {c: p for c, p in zip(all_classes, probs.tolist())}
|
37 |
|
38 |
|
@@ -50,8 +46,10 @@ demo = gr.Interface(
|
|
50 |
live=True,
|
51 |
inputs=input,
|
52 |
outputs="label",
|
|
|
53 |
title="DigitNet",
|
54 |
description="A simple handwritten number and letter classifier.\n\nDraw a digit or letter in the box below and see the model's predictions.",
|
|
|
55 |
)
|
56 |
|
57 |
if __name__ == "__main__":
|
|
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
+
from torchvision.utils import save_image
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
|
9 |
from digitnet import Model
|
10 |
|
11 |
torch.set_grad_enabled(False)
|
12 |
|
13 |
+
hf_hub_download("karanravindra/digitnet", filename="model.ckpt", local_dir=".")
|
14 |
|
15 |
+
model = Model.load_from_checkpoint("model.ckpt", map_location="cpu")
|
|
|
|
|
|
|
16 |
model.eval()
|
17 |
|
18 |
+
all_classes = list(map(str.upper, "0123456789"))
|
|
|
19 |
|
20 |
|
21 |
def predict(inputs: dict) -> dict[str, float]:
|
|
|
29 |
logits = model(img)
|
30 |
|
31 |
probs = torch.softmax(logits, dim=1).squeeze()
|
|
|
32 |
return {c: p for c, p in zip(all_classes, probs.tolist())}
|
33 |
|
34 |
|
|
|
46 |
live=True,
|
47 |
inputs=input,
|
48 |
outputs="label",
|
49 |
+
submit_btn=gr.Button("Predict"),
|
50 |
title="DigitNet",
|
51 |
description="A simple handwritten number and letter classifier.\n\nDraw a digit or letter in the box below and see the model's predictions.",
|
52 |
+
flagging_mode="never"
|
53 |
)
|
54 |
|
55 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
digitnet @ git+https://github.com/karanravindra/digitnet@main
|
2 |
|
3 |
# This file was autogenerated via `uv export`.
|
4 |
-
-e .
|
5 |
aiofiles==23.2.1
|
6 |
aiohappyeyeballs==2.4.3
|
7 |
aiohttp==3.10.10
|
|
|
1 |
digitnet @ git+https://github.com/karanravindra/digitnet@main
|
2 |
|
3 |
# This file was autogenerated via `uv export`.
|
|
|
4 |
aiofiles==23.2.1
|
5 |
aiohappyeyeballs==2.4.3
|
6 |
aiohttp==3.10.10
|