Spaces:
Build error
Build error
import torch | |
import gradio as gr | |
from PIL import Image | |
from torch import nn | |
from torchvision import transforms | |
# Display label for each of the model's 10 output logits, in logit order
# (logit index i -> the string form of digit i).
classes = [str(digit) for digit in range(10)]
class NeuralNetwork(nn.Module): | |
def __init__(self): | |
super(NeuralNetwork, self).__init__() | |
self.flatten = nn.Flatten() | |
self.linear_relu_stack = nn.Sequential( | |
nn.Linear(28 * 28, 784), | |
nn.ReLU(), | |
nn.Linear(784, 784), | |
nn.ReLU(), | |
nn.Linear(784, 10) | |
) | |
def forward(self, x): | |
x = self.flatten(x) | |
logits = self.linear_relu_stack(x) | |
return logits | |
# Instantiate the network and restore the trained weights from model.pth.
# map_location remaps every stored tensor onto the CPU, so the checkpoint
# loads on hosts without a GPU even if it was saved from one.
model = NeuralNetwork()
model.load_state_dict(
    torch.load("model.pth", map_location=torch.device("cpu"))
)
# eval() returns the module itself. The network contains only Linear/ReLU
# layers, so this is the conventional serving state rather than a numerical
# change.
model = model.eval()
def image_classifier(img_input):
    """Classify one handwritten-digit image with the loaded ``model``.

    Args:
        img_input: Image array as produced by ``gr.Image(type="numpy")``.
            Presumably (H, W, 3) RGB; (H, W) grayscale and (H, W, 4) RGBA
            arrays are accepted too, since PIL infers the mode from the shape.

    Returns:
        dict mapping each label in ``classes`` to its softmax probability
        (the format ``gr.Label`` consumes).

    Raises:
        ValueError: if no image was supplied (``img_input`` is None).
    """
    if img_input is None:
        raise ValueError("No image supplied; please upload or draw a digit.")
    # Collapse to a single luminance plane: the network takes exactly 784
    # inputs per sample, i.e. one 28x28 channel, not three colour planes.
    img = Image.fromarray(img_input.astype("uint8")).convert("L")
    # The UI component does not resize for us, so normalise arbitrary
    # uploads to the 28x28 resolution the model was built for.
    if img.size != (28, 28):
        img = img.resize((28, 28))
    # ToTensor scales pixels to [0, 1] and yields (1, 28, 28); unsqueeze adds
    # an explicit batch dim -> (1, 1, 28, 28), which Flatten maps to (1, 784).
    batch = transforms.ToTensor()(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)[0]
    probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(classes, probs)}


# gr.Image's `shape` argument was removed in Gradio 4, so the 28x28 resize is
# done inside image_classifier instead; type="numpy" matches what it expects.
demo = gr.Interface(
    fn=image_classifier,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=4),
    examples=["mnist_0.png", "mnist_2.png", "mnist_3.png"],
)
demo.launch()