import gradio as gr import numpy import torch from PIL import Image from torch import nn class MNISTModel(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10), # 0-9の数字のいずれかであるから、10クラス分類問題となる。よって出力は10次元 nn.LogSoftmax(dim=1) ) def forward(self, x: torch.Tensor): x = x.view(-1, 28 * 28) x = self.model(x) return x model = MNISTModel() state_dict = torch.load('mnist_model.pth', map_location='cpu') model.load_state_dict(state_dict, strict=False) model.eval() def predict(image): if not image: return None # convert bg black to white image = image['composite'] # 画像を白黒に変換 image_gray = image.convert('L') # リサイズ image_resized = image_gray.resize((28, 28), Image.LANCZOS) # 画像をテンソルに変換 x = torch.tensor(numpy.array(image_resized), dtype=torch.float32) with torch.no_grad(): output = model(x.unsqueeze(0)) output.argmax(dim=1) _, predicted = torch.max(output, 1) return predicted.item() input_component = gr.Sketchpad(type='pil', height=280, width=280, layers=False, image_mode='L', brush=gr.Brush(default_color='auto', color_mode='defaults')) iface = gr.Interface(fn=predict, inputs=input_component, outputs="text", live=True, clear_btn=gr.ClearButton(visible=False), allow_flagging='never') iface.launch()