# handwritting/app.py - Hugging Face Space file by dschandra.
# Last commit: "Update app.py" (d37169c, verified); file size 1.74 kB.
# Import necessary libraries
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
# Fully connected MNIST classifier: 784 -> 128 -> 64 -> 10.
class Net(nn.Module):
    """Three-layer perceptron mapping a 28x28 image to 10 digit log-probabilities.

    The attribute names ``fc1``/``fc2``/``fc3`` must not change: they are the
    keys ``load_state_dict`` uses to match the saved weights.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)  # 784 = 28 * 28 input pixels
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)    # one output per digit class

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 10).

        ``x`` is flattened with ``view(-1, 784)``, so any tensor whose element
        count is a multiple of 784 is accepted.
        """
        flat = x.view(-1, 784)
        for hidden in (self.fc1, self.fc2):
            flat = F.relu(hidden(flat))
        logits = self.fc3(flat)
        return F.log_softmax(logits, dim=1)
# Initialize the model and load the trained weights.
# map_location="cpu": torch.load otherwise tries to restore each tensor onto
# the device it was saved from, which raises on a CPU-only host if the
# checkpoint was written from a GPU session.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; on torch>=1.13 consider passing weights_only=True.
model = Net()
model.load_state_dict(torch.load('mnist_model.pth', map_location='cpu'))
model.eval()  # switch to inference mode before serving predictions
# Preprocessing pipeline applied to every input image before inference:
#   1. resize to 28x28, the resolution Net's first layer (28 * 28 inputs) expects;
#   2. collapse to a single grayscale channel;
#   3. convert to a float tensor scaled to [0, 1];
#   4. map each value v to (v - 0.5) / 0.5, i.e. into [-1, 1].
# NOTE(review): the Normalize constants must match those used when
# mnist_model.pth was trained - confirm against the training script.
transform = transforms.Compose(transforms=[
    transforms.Resize(size=(28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# Define the prediction function
def predict_image(img):
    """Predict the handwritten digit shown in ``img``.

    Args:
        img: A PIL image or a NumPy image array. Gradio's ``Image`` component
            delivers ``numpy.ndarray`` by default. Arrays are assumed to be
            uint8 pixel data (Gradio's format) - confirm for other callers.
            ``None`` is accepted, e.g. when the Gradio input is cleared.

    Returns:
        The predicted digit as an ``int``, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # With live=True the app can be invoked with an empty input; there is
        # nothing to classify.
        return None
    if isinstance(img, np.ndarray):
        # torchvision's Resize/Grayscale accept PIL images or tensors, not
        # ndarrays, so convert before running the transform pipeline.
        img = Image.fromarray(img)
    batch = transform(img).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        output = model(batch)
    return output.argmax(dim=1).item()
# Create the Gradio interface.
# `gr.inputs` and the Image `shape`/`invert_colors` arguments were deprecated in
# Gradio 3 and removed in Gradio 4, where the original call fails at import
# time. Prefer the top-level `gr.Image` component. type='pil' hands
# predict_image a PIL image, which the torchvision transforms accept directly.
# Resizing to 28x28 is already done by `transform`, so `shape` is not needed.
if hasattr(gr, "Image"):
    image_input = gr.Image(image_mode='L', type='pil')
else:
    # Legacy Gradio 2.x fallback: the original configuration, unchanged.
    image_input = gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False)

iface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs='label',
    live=True,
    description="Upload an image of a handwritten digit, and the model will predict the digit.",
)
# Start the Gradio web server only when this file is executed as a script,
# not when it is imported by another module.
if __name__ == "__main__":
    iface.launch()