suepsakun commited on
Commit
2829422
·
verified ·
1 Parent(s): e418279

Delete [skooldio]_thai_handwritten_recognition_app_by_gradio (2).py

Browse files
[skooldio]_thai_handwritten_recognition_app_by_gradio (2).py DELETED
@@ -1,102 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """[Skooldio] Thai Handwritten Recognition App by Gradio
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1Feza00drEEejwPVzgdTmz60KArcms_w_
8
- """
9
-
10
- !pip install gradio==3.35.0
11
- !pip install torchvision
12
-
13
- !wget https://github.com/biodatlab/deep-learning-skooldio/raw/master/saved_model/thai_digit_net.pth
14
-
15
- import numpy as np
16
- import torch
17
- from pathlib import Path
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- from PIL import Image
21
- from torchvision import transforms
22
- import gradio as gr
23
-
24
-
25
- transform = transforms.Compose([
26
- transforms.Resize((28, 28)),
27
- transforms.Grayscale(),
28
- transforms.ToTensor()
29
- ])
30
- labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)", "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
31
- LABELS = {i:k for i, k in enumerate(labels)} # dictionary of index and label
32
-
33
-
34
- # Load model using DropoutThaiDigit instead
35
- class DropoutThaiDigit(nn.Module):
36
- def __init__(self):
37
- super(DropoutThaiDigit, self).__init__()
38
- self.fc1 = nn.Linear(28 * 28, 392)
39
- self.fc2 = nn.Linear(392, 196)
40
- self.fc3 = nn.Linear(196, 98)
41
- self.fc4 = nn.Linear(98, 10)
42
- self.dropout = nn.Dropout(0.1)
43
-
44
- def forward(self, x):
45
- x = x.view(-1, 28 * 28)
46
- x = self.fc1(x)
47
- x = F.relu(x)
48
- x = self.dropout(x)
49
- x = self.fc2(x)
50
- x = F.relu(x)
51
- x = self.dropout(x)
52
- x = self.fc3(x)
53
- x = F.relu(x)
54
- x = self.dropout(x)
55
- x = self.fc4(x)
56
- return x
57
-
58
-
59
- model = DropoutThaiDigit()
60
- model.load_state_dict(torch.load("thai_digit_net.pth"))
61
- model.eval()
62
-
63
- def predict(img):
64
- if img.get("composite") is not None:
65
- if img["composite"].sum() == 0:
66
- return {"No input sketch": 0.0}
67
-
68
- img_data = img['composite']
69
- img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
70
- img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)
71
-
72
- with torch.no_grad():
73
- probs = model(img_tensor).softmax(dim=1).squeeze()
74
-
75
- probs, indices = torch.topk(probs, 5)
76
- return {LABELS[i]: float(p) for i, p in zip(indices.tolist(), probs.tolist())}
77
-
78
- demo = gr.Interface(
79
- fn=predict,
80
- inputs=gr.Sketchpad(
81
- label="Draw Here",
82
- image_mode="L",
83
- width=400,
84
- height=350
85
- ),
86
- outputs=gr.Label(label="Guess"),
87
- title="Thai Digit Handwritten Classification",
88
- description="วาดเลขไทยตั้งแต่ ๐ ถึง ๙",
89
- live=True
90
- )
91
-
92
- if __name__ == "__main__":
93
- demo.launch()
94
-
95
- gr.Interface(
96
- fn=predict,
97
- inputs=gr.Sketchpad(label="Draw Here", brush_radius=5, type="pil", shape=(120, 120)),
98
- outputs=gr.Label(label="Guess"),
99
- title="Thai Digit Handwritten Classification",
100
- live=True
101
- ).launch(enable_queue=True)
102
-