MNIST_DEMO / app.py
TakeshiSaito's picture
update
958dc68
raw
history blame
No virus
1.9 kB
import gradio as gr
import numpy
import torch
from PIL import Image
from torch import nn
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10), # 0-9の数字のいずれかであるから、10クラス分類問題となる。よって出力は10次元
nn.LogSoftmax(dim=1)
)
def forward(self, x: torch.Tensor):
x = x.view(-1, 28 * 28)
x = self.model(x)
return x
model = MNISTModel()
state_dict = torch.load('mnist_model.pth', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()
def predict(image):
if not image:
return None
# convert bg black to white
image = image['composite']
# 画像を白黒に変換
image_gray = image.convert('L')
# リサイズ
image_resized = image_gray.resize((28, 28), Image.LANCZOS)
# 画像をテンソルに変換
x = torch.tensor(numpy.array(image_resized), dtype=torch.float32)
with torch.no_grad():
output = model(x.unsqueeze(0))
output.argmax(dim=1)
_, predicted = torch.max(output, 1)
return predicted.item()
input_component = gr.Sketchpad(type='pil',
height=280,
width=280,
layers=False,
image_mode='L',
brush=gr.Brush(default_color='auto', color_mode='defaults'))
iface = gr.Interface(fn=predict,
inputs=input_component,
outputs="text",
live=True,
clear_btn=gr.ClearButton(visible=False),
allow_flagging='never')
iface.launch()