Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy | |
import torch | |
from PIL import Image | |
from torch import nn | |
class MNISTModel(nn.Module):
    """Fully connected classifier for 28x28 digit images.

    Each image is flattened to 784 values, passed through two hidden
    ReLU layers (512 and 256 units) and mapped to log-probabilities over
    the ten digit classes.
    """

    def __init__(self):
        super().__init__()
        # The attribute name and the layer order determine the state-dict
        # keys (model.0.*, model.2.*, model.4.*); keep them stable so that
        # saved checkpoints continue to load.
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            # The target is one of the digits 0-9, i.e. a 10-class
            # classification problem, hence a 10-dimensional output.
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose element count is a multiple of 28 * 28; it is
                flattened to shape (batch, 784) before the layers run.

        Returns:
            Tensor of shape (batch, 10) holding log-softmax scores.
        """
        # reshape (rather than view) also accepts non-contiguous input,
        # e.g. a transposed or sliced tensor, copying only when required.
        flat = x.reshape(-1, 28 * 28)
        return self.model(flat)
# Build the network and restore its trained parameters onto the CPU.
model = MNISTModel()
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source; consider weights_only=True if the installed
# torch version supports it.
state_dict = torch.load('mnist_model.pth', map_location='cpu')
# strict=False tolerates key mismatches between the checkpoint and the
# module. Report them instead of discarding the result, otherwise the app
# could silently run with partially (or entirely) untrained weights.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        'mnist_model.pth does not fully match MNISTModel: '
        f'missing={list(_load_result.missing_keys)}, '
        f'unexpected={list(_load_result.unexpected_keys)}'
    )
# The model is only used for inference; put it in evaluation mode.
model.eval()
def predict(image):
    """Classify the digit drawn on the sketchpad.

    Args:
        image: Value supplied by the sketchpad input: a dict whose
            'composite' entry holds the drawing as a PIL image, or a
            falsy value when there is no input.

    Returns:
        The predicted digit (0-9) as an int, or None when there is no
        drawing to classify.
    """
    if not image:
        return None
    drawing = image.get('composite')
    if drawing is None:
        # Nothing has been composited yet, so there is nothing to classify.
        return None
    # Convert to greyscale, then shrink to the 28x28 resolution that the
    # network's first layer expects.
    grey = drawing.convert('L')
    small = grey.resize((28, 28), Image.LANCZOS)
    # Convert the image to a float tensor.
    # NOTE(review): pixel values are passed in the raw 0-255 range and the
    # colours are not inverted; confirm this matches the preprocessing used
    # when mnist_model.pth was trained.
    x = torch.tensor(numpy.array(small), dtype=torch.float32)
    with torch.no_grad():
        # Add a leading batch dimension: (28, 28) -> (1, 28, 28).
        output = model(x.unsqueeze(0))
    # Index of the highest log-probability is the predicted digit.
    return output.argmax(dim=1).item()
# Drawing surface: a single-layer greyscale sketchpad (height and width
# of 280) that hands the drawing to the callback as PIL images.
_brush = gr.Brush(default_color='auto', color_mode='defaults')
input_component = gr.Sketchpad(
    type='pil',
    height=280,
    width=280,
    layers=False,
    image_mode='L',
    brush=_brush,
)

# Wire the sketchpad to predict and show the result as text. The
# interface runs in live mode and its clear button is created hidden.
_hidden_clear = gr.ClearButton(visible=False)
iface = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs="text",
    live=True,
    clear_btn=_hidden_clear,
)

# Start serving the app.
iface.launch()