|
import numpy as np |
|
import torch |
|
from pathlib import Path |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
from torchvision import transforms |
|
import gradio as gr |
|
|
|
|
|
# Image preprocessing pipeline: resize to the 28x28 input size used by the
# network, collapse to one grayscale channel, then convert to a float tensor.
# NOTE(review): predict() below builds its own PIL -> tensor conversion and
# does not reference this object in the visible code; kept for other users.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)
|
# Display names for the 10 output classes, ordered 0..9: the Thai numeral
# followed by its spelled-out Thai name. Position i names model output i
# (assumes training used the digit value as the class index — confirm).
labels = [
    "๐ (ศูนย์)",
    "๑ (หนึ่ง)",
    "๒ (สอง)",
    "๓ (สาม)",
    "๔ (สี่)",
    "๕ (ห้า)",
    "๖ (หก)",
    "๗ (เจ็ด)",
    "๘ (แปด)",
    "๙ (เก้า)",
]

# Class index -> display label; used to key the dict returned by predict().
LABELS = dict(enumerate(labels))
|
|
|
|
|
|
|
class DropoutThaiDigit(nn.Module):
    """Fully connected classifier for 28x28 single-channel digit images.

    Architecture: 784 -> 392 -> 196 -> 98 -> 10 linear layers, with ReLU
    followed by dropout (p=0.1) after each of the three hidden layers.
    ``forward`` returns raw logits (no softmax) for the 10 classes.

    The attribute names ``fc1``..``fc4`` determine the ``state_dict`` keys,
    so they must stay unchanged for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, 10)
        # A single parameter-free Dropout module, reused after every hidden layer.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return class logits of shape (N, 10).

        ``x`` is flattened to (N, 784), so any tensor whose element count is a
        multiple of 784 is accepted, e.g. (N, 1, 28, 28) or (N, 28, 28).
        """
        # reshape rather than view: view raises on non-contiguous inputs
        # (e.g. transposed tensors); reshape copies only when it must.
        x = x.reshape(-1, 28 * 28)
        # Same op order as before: linear -> ReLU -> dropout per hidden layer.
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = self.dropout(F.relu(hidden(x)))
        return self.fc4(x)
|
|
|
|
|
# Build the network and load trained weights; the path is resolved against
# the current working directory.
model = DropoutThaiDigit()
# map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
# CPU-only host (torch.load otherwise raises). The resulting model is the
# same either way, since load_state_dict copies into the CPU parameters.
model.load_state_dict(torch.load("thai_digit_net.pth", map_location="cpu"))
# Inference only: eval() disables the dropout layers.
model.eval()
|
|
|
|
|
def predict(img):
    """
    Classify a sketch and return the top 5 predictions.

    Args:
        img: Value emitted by the Gradio Sketchpad — a dict whose
            ``"composite"`` entry holds the flattened drawing as an image
            array. May be ``None`` (presumably when the editor has no value).

    Returns:
        A dictionary ``{label: confidence, label: confidence, ...}`` holding
        the five most probable labels, or ``{"No input sketch": 0.0}`` when
        there is nothing to classify.
    """
    composite = img.get("composite") if img is not None else None

    # Nothing to classify: no editor value, no composite, or an all-zero
    # (blank) canvas. Previously a missing composite fell through and
    # crashed inside Image.fromarray(None).
    if composite is None or composite.sum() == 0:
        return {"No input sketch": 0.0}

    img_gray = Image.fromarray(composite).convert('L').resize((28, 28))
    # Single-channel image -> (1, 28, 28) tensor; unsqueeze adds the batch dim.
    img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)

    with torch.no_grad():
        probs = model(img_tensor).softmax(dim=1).squeeze()

    top_probs, top_indices = torch.topk(probs, 5)
    return {
        LABELS[i]: float(p)
        for i, p in zip(top_indices.tolist(), top_probs.tolist())
    }
|
|
|
# JavaScript handed to gr.Blocks(js=...), which Gradio runs on page load.
# If the URL does not already request the dark theme, it sets
# ``__theme=dark`` and reloads the page, forcing dark mode.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""

# UI: a live Sketchpad feeding predict(), with results shown in a Label.
with gr.Blocks(js=js_func) as demo:
    # NOTE(review): this Interface is constructed inside a Blocks context
    # without an explicit .render(); confirm against the installed Gradio
    # version that it is actually displayed (alternatively pass js=js_func
    # to gr.Interface directly).
    gr.Interface(
        fn=predict,
        inputs=gr.Sketchpad(
            label="Draw Here",
            # White-only strokes; presumably chosen to match light-on-dark
            # training digits (predict treats an all-zero canvas as empty) —
            # confirm against the training data.
            brush=gr.Brush(default_size=14, default_color="#FFFFFF", colors=["#FFFFFF"]),
            # Single-channel ("L") canvas, matching the grayscale input
            # predict() converts to.
            image_mode="L",
            layers=False,
            eraser=None,
            width=400,
            height=350
        ),
        outputs=gr.Label(label="Guess"),
        title="Thai Digit Handwritten Classification",
        # Thai: "Try drawing a Thai numeral in the Sketchpad below to predict
        # the digit, from ๐ (zero) through ๙ (nine)."
        description="ทดลองวาดภาพตัวอักษรเลขไทยลงใน Sketchpad ด้านล่างเพื่อทำนายผลตัวเลข ตั้งแต่ ๐ (ศูนย์) ๑ (หนึ่ง) ๒ (สอง) ๓ (สาม) ๔ (สี่) ๕ (ห้า) ๖ (หก) ๗ (เจ็ด) ๘ (แปด) จนถึง ๙ (เก้า)",
        # Re-run predict() whenever the sketch changes (no Submit button).
        live=True
    )

if __name__ == "__main__":
    demo.launch()
|
|