File size: 1,901 Bytes
016a471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dac00a9
 
 
 
 
 
 
 
 
 
958dc68
 
dac00a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import numpy
import torch
from PIL import Image
from torch import nn


class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10),  # 0-9の数字のいずれかであるから、10クラス分類問題となる。よって出力は10次元
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x: torch.Tensor):
        x = x.view(-1, 28 * 28)
        x = self.model(x)
        return x


model = MNISTModel()
state_dict = torch.load('mnist_model.pth', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()


def predict(image):
    if not image:
        return None
    # convert bg black to white
    image = image['composite']
    # 画像を白黒に変換
    image_gray = image.convert('L')
    # リサイズ
    image_resized = image_gray.resize((28, 28), Image.LANCZOS)
    # 画像をテンソルに変換
    x = torch.tensor(numpy.array(image_resized), dtype=torch.float32)

    with torch.no_grad():
        output = model(x.unsqueeze(0))
        output.argmax(dim=1)
        _, predicted = torch.max(output, 1)

    return predicted.item()


input_component = gr.Sketchpad(type='pil',
                               height=280,
                               width=280,
                               layers=False,
                               image_mode='L',
                               brush=gr.Brush(default_color='auto', color_mode='defaults'))
iface = gr.Interface(fn=predict,
                     inputs=input_component,
                     outputs="text",
                     live=True,
                     clear_btn=gr.ClearButton(visible=False),
                     allow_flagging='never')
iface.launch()