# hindi-mnist / app.py
# Uploaded by bhaveshgoel07 - commit 7355da1 ("Fixed errors"), 1.83 kB.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image
# Define the CNN
class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two stages of 3x3 convolution (padding 1, so height and width are kept),
    ReLU and 2x2 max-pooling (which halves height and width) are followed by
    two fully connected layers that emit 10 unnormalised class scores.

    The layer attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``) are
    the prefixes of the ``state_dict`` keys, so they must stay as they are
    for previously saved weights to keep loading.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Float tensor whose trailing dimensions are ``(1, 28, 28)``
                (channel, height, width).

        Returns:
            Tensor with 10 logits per image.

        Raises:
            RuntimeError: If the pooled feature maps are not ``64 x 7 x 7``,
                i.e. the input was not 28x28.
        """
        x = self.pool(self.relu(self.conv1(x)))  # 28x28 input -> 32 x 14 x 14
        x = self.pool(self.relu(self.conv2(x)))  # -> 64 x 7 x 7

        # Without this check, an input whose feature volume happens to be a
        # multiple of 64*7*7 (e.g. 56x56) would be reshaped into extra batch
        # rows and produce plausible-looking but meaningless predictions.
        # RuntimeError matches what ``view`` already raises for sizes that
        # do not divide evenly.
        if tuple(x.shape[-3:]) != (64, 7, 7):
            raise RuntimeError(
                "SimpleCNN expects 1x28x28 inputs; got pooled features of "
                f"shape {tuple(x.shape[-3:])} instead of (64, 7, 7)"
            )

        x = x.view(-1, 64 * 7 * 7)   # flatten each image to 3136 features
        x = self.relu(self.fc1(x))   # 128 hidden units
        return self.fc2(x)           # 10 logits
# Load the trained model
# Build the network, restore the saved parameters onto the CPU (so the
# checkpoint loads even when it was saved from a GPU), then switch the
# module to evaluation mode for serving.
model = SimpleCNN()
model.load_state_dict(
    torch.load('mnist_cnn.pth', map_location='cpu')
)
model.eval()
# Define the transformation for the input image
transform = transforms.Compose([
    # Collapse the input to a single channel, matching conv1's in_channels=1.
    transforms.Grayscale(num_output_channels=1),
    # The network's first linear layer is sized for 28x28 inputs.
    transforms.Resize(size=(28, 28)),
    # Convert the image to a float tensor.
    transforms.ToTensor(),
    # Shift and scale with mean 0.5 and std 0.5 on the single channel.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# Prediction function
def predict(image):
    """Classify one image and return the probability of each class.

    Args:
        image: An image accepted by the module-level ``transform`` pipeline
            (the interface is configured to pass PIL images), or ``None``
            when there is nothing to classify.

    Returns:
        A dict mapping the class labels ``"0"`` .. ``"9"`` to their softmax
        probabilities, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # NOTE(review): presumably what the UI sends for an empty
        # submission; returning None avoids crashing inside the transform
        # pipeline. Confirm the output component accepts None.
        return None

    batch = transform(image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)

    # Softmax over the class dimension, then take the only row of the batch.
    probabilities = torch.softmax(logits, dim=1)[0]
    return {str(label): p for label, p in enumerate(probabilities.tolist())}
# Create the Gradio interface
interface = gr.Interface(
    predict,                       # fn: called with the drawing
    gr.Sketchpad(type='pil'),      # inputs: drawing canvas, requested as PIL
    gr.Label(num_top_classes=10),  # outputs: shows up to 10 classes
)

# Launch the interface
interface.launch()