suepsakun commited on
Commit
e418279
·
verified ·
1 Parent(s): d4165b1

Upload [skooldio]_thai_handwritten_recognition_app_by_gradio (2).py

Browse files
[skooldio]_thai_handwritten_recognition_app_by_gradio (2).py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """[Skooldio] Thai Handwritten Recognition App by Gradio
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1Feza00drEEejwPVzgdTmz60KArcms_w_
8
+ """
9
+
10
+ !pip install gradio==3.35.0
11
+ !pip install torchvision
12
+
13
+ !wget https://github.com/biodatlab/deep-learning-skooldio/raw/master/saved_model/thai_digit_net.pth
14
+
15
+ import numpy as np
16
+ import torch
17
+ from pathlib import Path
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from PIL import Image
21
+ from torchvision import transforms
22
+ import gradio as gr
23
+
24
+
25
+ transform = transforms.Compose([
26
+ transforms.Resize((28, 28)),
27
+ transforms.Grayscale(),
28
+ transforms.ToTensor()
29
+ ])
30
+ labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)", "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
31
+ LABELS = {i:k for i, k in enumerate(labels)} # dictionary of index and label
32
+
33
+
34
+ # Load model using DropoutThaiDigit instead
35
+ class DropoutThaiDigit(nn.Module):
36
+ def __init__(self):
37
+ super(DropoutThaiDigit, self).__init__()
38
+ self.fc1 = nn.Linear(28 * 28, 392)
39
+ self.fc2 = nn.Linear(392, 196)
40
+ self.fc3 = nn.Linear(196, 98)
41
+ self.fc4 = nn.Linear(98, 10)
42
+ self.dropout = nn.Dropout(0.1)
43
+
44
+ def forward(self, x):
45
+ x = x.view(-1, 28 * 28)
46
+ x = self.fc1(x)
47
+ x = F.relu(x)
48
+ x = self.dropout(x)
49
+ x = self.fc2(x)
50
+ x = F.relu(x)
51
+ x = self.dropout(x)
52
+ x = self.fc3(x)
53
+ x = F.relu(x)
54
+ x = self.dropout(x)
55
+ x = self.fc4(x)
56
+ return x
57
+
58
+
59
+ model = DropoutThaiDigit()
60
+ model.load_state_dict(torch.load("thai_digit_net.pth"))
61
+ model.eval()
62
+
63
+ def predict(img):
64
+ if img.get("composite") is not None:
65
+ if img["composite"].sum() == 0:
66
+ return {"No input sketch": 0.0}
67
+
68
+ img_data = img['composite']
69
+ img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
70
+ img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)
71
+
72
+ with torch.no_grad():
73
+ probs = model(img_tensor).softmax(dim=1).squeeze()
74
+
75
+ probs, indices = torch.topk(probs, 5)
76
+ return {LABELS[i]: float(p) for i, p in zip(indices.tolist(), probs.tolist())}
77
+
78
+ demo = gr.Interface(
79
+ fn=predict,
80
+ inputs=gr.Sketchpad(
81
+ label="Draw Here",
82
+ image_mode="L",
83
+ width=400,
84
+ height=350
85
+ ),
86
+ outputs=gr.Label(label="Guess"),
87
+ title="Thai Digit Handwritten Classification",
88
+ description="วาดเลขไทยตั้งแต่ ๐ ถึง ๙",
89
+ live=True
90
+ )
91
+
92
+ if __name__ == "__main__":
93
+ demo.launch()
94
+
95
+ gr.Interface(
96
+ fn=predict,
97
+ inputs=gr.Sketchpad(label="Draw Here", brush_radius=5, type="pil", shape=(120, 120)),
98
+ outputs=gr.Label(label="Guess"),
99
+ title="Thai Digit Handwritten Classification",
100
+ live=True
101
+ ).launch(enable_queue=True)
102
+