suepsakun commited on
Commit
2385619
·
verified ·
1 Parent(s): 2829422

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from pathlib import Path
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import gradio as gr
9
+
10
+
11
+ transform = transforms.Compose([
12
+ transforms.Resize((28, 28)),
13
+ transforms.Grayscale(),
14
+ transforms.ToTensor()
15
+ ])
16
+ labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)", "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
17
+ LABELS = {i:k for i, k in enumerate(labels)} # dictionary of index and label
18
+
19
+
20
+ # Load model using DropoutThaiDigit instead
21
+ class DropoutThaiDigit(nn.Module):
22
+ def __init__(self):
23
+ super(DropoutThaiDigit, self).__init__()
24
+ self.fc1 = nn.Linear(28 * 28, 392)
25
+ self.fc2 = nn.Linear(392, 196)
26
+ self.fc3 = nn.Linear(196, 98)
27
+ self.fc4 = nn.Linear(98, 10)
28
+ self.dropout = nn.Dropout(0.1)
29
+
30
+ def forward(self, x):
31
+ x = x.view(-1, 28 * 28)
32
+ x = self.fc1(x)
33
+ x = F.relu(x)
34
+ x = self.dropout(x)
35
+ x = self.fc2(x)
36
+ x = F.relu(x)
37
+ x = self.dropout(x)
38
+ x = self.fc3(x)
39
+ x = F.relu(x)
40
+ x = self.dropout(x)
41
+ x = self.fc4(x)
42
+ return x
43
+
44
+
45
+ model = DropoutThaiDigit()
46
+ model.load_state_dict(torch.load("thai_digit_net.pth"))
47
+ model.eval()
48
+
49
+
50
+ def predict(img):
51
+ """
52
+ Predict function takes image and return top 5 predictions
53
+ as a dictionary:
54
+ {label: confidence, label: confidence, ...}
55
+ """
56
+ if img.get("composite") is not None:
57
+ if img["composite"].sum() == 0:
58
+ return {"No input sketch": 0.0}
59
+
60
+ img_data = img['composite']
61
+ img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
62
+ img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)
63
+
64
+ # Make prediction
65
+ with torch.no_grad():
66
+ probs = model(img_tensor).softmax(dim=1).squeeze()
67
+
68
+ probs, indices = torch.topk(probs, 5) # select top 5
69
+ probs, indices = probs.tolist(), indices.tolist() # transform to list
70
+ return {LABELS[i]: float(v) for i, v in zip(indices, probs)}
71
+
72
+ js_func = """
73
+ function refresh() {
74
+ const url = new URL(window.location);
75
+
76
+ if (url.searchParams.get('__theme') !== 'dark') {
77
+ url.searchParams.set('__theme', 'dark');
78
+ window.location.href = url.href;
79
+ }
80
+ }
81
+ """
82
+
83
+ with gr.Blocks(js=js_func) as demo:
84
+ gr.Interface(
85
+ fn=predict,
86
+ inputs=gr.Sketchpad(
87
+ label="Draw Here",
88
+ brush=gr.Brush(default_size=14, default_color="#FFFFFF", colors=["#FFFFFF"]),
89
+ image_mode="L",
90
+ layers=False,
91
+ eraser=None,
92
+ width=400,
93
+ height=350
94
+ ),
95
+ outputs=gr.Label(label="Guess"),
96
+ title="Thai Digit Handwritten Classification",
97
+ description="ทดลองวาดภาพตัวอักษรเลขไทยลงใน Sketchpad ด้านล่างเพื่อทำนายผลตัวเลข ตั้งแต่ ๐ (ศูนย์) ๑ (หนึ่ง) ๒ (สอง) ๓ (สาม) ๔ (สี่) ๕ (ห้า) ๖ (หก) ๗ (เจ็ด) ๘ (แปด) จนถึง ๙ (เก้า)",
98
+ live=True
99
+ )
100
+
101
+ if __name__ == "__main__":
102
+ demo.launch()