"""Gradio demo that classifies hand-written digits with a small MLP.

The network weights are read from ``model.pth`` (a ``state_dict`` for
``NeuralNetwork``) and served through a ``gr.Interface`` that shows the
top-scoring digit probabilities for an input image.
"""

import torch
import gradio as gr
from PIL import Image
from torch import nn
from torchvision import transforms

# Label for each output logit, indexed by logit position.
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

# Side length (pixels) of the square grayscale images the network consumes.
_IMAGE_SIZE = 28


class NeuralNetwork(nn.Module):
    """Fully-connected ReLU network mapping a 28x28 image to 10 digit logits."""

    def __init__(self):
        super().__init__()
        # Attribute names and layer order must stay as-is: they are the keys
        # of the state_dict loaded from model.pth.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 784),
            nn.ReLU(),
            nn.Linear(784, 784),
            nn.ReLU(),
            nn.Linear(784, 10),
        )

    def forward(self, x):
        """Return unnormalized logits of shape (batch, 10).

        ``nn.Flatten`` keeps dim 0 as the batch axis and flattens the rest,
        so each sample must contain exactly 28 * 28 values.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork()
# NOTE: torch.load unpickles the checkpoint; only load model.pth from a
# trusted source.
model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
model.eval()

# Built once and reused: converts a PIL image to a float tensor in [0, 1]
# with a leading channel axis.
_to_tensor = transforms.ToTensor()


def image_classifier(img_input):
    """Classify a digit image and return a probability per digit label.

    Args:
        img_input: Image array supplied by ``gr.Image`` (H x W, H x W x 3 or
            H x W x 4; converted to uint8 before use).

    Returns:
        dict mapping every label in ``classes`` to its softmax probability.
    """
    # Let Pillow infer the mode from the array shape, then collapse to one
    # grayscale channel: the network expects 28*28 values per sample, not
    # 3*28*28 (which nn.Flatten would otherwise split into a batch of 3).
    img = Image.fromarray(img_input.astype("uint8")).convert("L")
    if img.size != (_IMAGE_SIZE, _IMAGE_SIZE):
        img = img.resize((_IMAGE_SIZE, _IMAGE_SIZE))
    # (1, H, W) -> (1, 1, H, W): add the batch axis the model expects.
    batch = _to_tensor(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)[0]
    probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(classes, probs)}


# NOTE(review): `shape=` is a Gradio 3.x argument of gr.Image; the classifier
# resizes on its own, so predictions do not depend on it.
demo = gr.Interface(
    fn=image_classifier,
    inputs=gr.Image(shape=(_IMAGE_SIZE, _IMAGE_SIZE)),
    outputs=gr.Label(num_top_classes=4),
    examples=["mnist_0.png", "mnist_2.png", "mnist_3.png"],
)
demo.launch()