File size: 1,736 Bytes
fa9b186
 
 
 
 
 
0886b2c
fa9b186
 
985e70c
fa9b186
 
 
 
 
 
 
 
d37169c
fa9b186
 
 
 
 
0886b2c
fa9b186
 
0886b2c
 
 
 
 
 
 
 
 
fa9b186
985e70c
fa9b186
0886b2c
 
fa9b186
0886b2c
 
fa9b186
 
 
 
0886b2c
 
 
fa9b186
 
 
 
985e70c
fa9b186
d219cb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Import necessary libraries
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

# Fully connected MNIST classifier: 784 -> 128 -> 64 -> 10.
class Net(nn.Module):
    """Three-layer perceptron that scores an image against the 10 digit classes.

    The attribute names ``fc1``, ``fc2`` and ``fc3`` determine the parameter
    keys in ``state_dict`` (e.g. ``fc1.weight``), so they must match the saved
    checkpoint that is restored with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-softmax scores with one row of 10 values per input sample.

        The input is reshaped into rows of 28 * 28 = 784 values, so its total
        element count must be a multiple of 784.
        """
        flat = x.view(-1, 28 * 28)
        # Both hidden layers use a ReLU non-linearity.
        for hidden in (self.fc1, self.fc2):
            flat = F.relu(hidden(flat))
        logits = self.fc3(flat)
        return F.log_softmax(logits, dim=1)

# Initialize the model and load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a CUDA device load on
# a CPU-only host; nothing in this file moves the model to a GPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True on torch >= 1.13).
model = Net()
model.load_state_dict(torch.load('mnist_model.pth', map_location='cpu'))
model.eval()  # switch to inference mode before serving predictions

# Preprocessing applied to each input image before inference: resize to 28x28,
# collapse to a single channel, convert to a [0, 1] float tensor, then rescale.
transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # (v - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
        # NOTE(review): must match the normalization used when
        # mnist_model.pth was trained — confirm.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Define the prediction function
def predict_image(img):
    """Classify a single handwritten-digit image and return the predicted digit.

    Args:
        img: The image coming from the Gradio input. Gradio's ``Image``
            component delivers a NumPy array by default (``type="numpy"``);
            a PIL image is also accepted. ``None`` (no image present) is
            tolerated.

    Returns:
        The index (0-9) of the highest-scoring class as a Python ``int``,
        or ``None`` when no image was supplied.
    """
    if img is None:
        # With live=True the function can fire while the input is empty;
        # there is nothing to classify.
        return None
    if isinstance(img, np.ndarray):
        # torchvision's Resize/Grayscale operate on PIL images or tensors,
        # not NumPy arrays, so convert first. Assumes a uint8 array as
        # produced by Gradio — TODO confirm for other input sources.
        img = Image.fromarray(img)
    batch = transform(img).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        log_probs = model(batch)
    return log_probs.argmax(dim=1).item()

# Build the Gradio UI: a 28x28 grayscale image in, the predicted label out.
# NOTE(review): ``gr.inputs`` is the legacy Gradio 2.x/3.x namespace and is
# gone in Gradio 4; migrating to ``gr.Image`` needs care because ``shape`` and
# ``invert_colors`` may not be accepted there.
_digit_input = gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False)
iface = gr.Interface(
    fn=predict_image,
    inputs=_digit_input,
    outputs='label',
    live=True,
    description="Upload an image of a handwritten digit, and the model will predict the digit.",
)

def main():
    """Start the Gradio web app that serves the digit classifier."""
    iface.launch()


# Launch the interface only when executed as a script, not when imported.
if __name__ == '__main__':
    main()