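# NOTE: the following lines appear to be page chrome copied from a hosted
# file viewer, not part of the program; kept as comments so the file parses.
# Manoharmata's picture
# Upload 2 files
# 3649097 verified
# Raw
# History Blame Contribute Delete
# 2.26 kB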
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
class BestMNISTCNN(nn.Module):
    """Three-stage convolutional classifier that outputs 10 logits.

    Each stage is two 3x3 ``Conv2d -> BatchNorm2d -> ReLU`` units, followed by
    a downsampling layer and dropout. The first two stages halve the spatial
    size with ``MaxPool2d(2)``. The last stage collapses it to 1x1 with
    ``AdaptiveAvgPool2d``, so the flattened feature vector always has 128
    entries.

    NOTE: the attribute names (``convblock1``..``convblock3``, ``fc``) and the
    layer order inside each ``nn.Sequential`` define the ``state_dict`` keys
    (e.g. ``convblock2.3.weight``). Keep them stable so saved checkpoints
    still load.
    """

    @staticmethod
    def _stage(in_ch, out_ch, downsample, drop):
        """Build one stage: (Conv2d, BatchNorm2d, ReLU) x2, then ``downsample`` and Dropout."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            downsample,
            nn.Dropout(drop),
        )

    def __init__(self):
        super().__init__()
        self.convblock1 = self._stage(1, 32, nn.MaxPool2d(2), 0.1)
        self.convblock2 = self._stage(32, 64, nn.MaxPool2d(2), 0.1)
        self.convblock3 = self._stage(64, 128, nn.AdaptiveAvgPool2d((1, 1)), 0.2)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of single-channel images ``(N, 1, H, W)`` to logits ``(N, 10)``."""
        for stage in (self.convblock1, self.convblock2, self.convblock3):
            x = stage(x)
        # After global average pooling x is (N, 128, 1, 1); flatten per sample.
        flat = x.view(x.shape[0], -1)
        return self.fc(flat)
# Build the network and load the trained weights from disk.
model = BestMNISTCNN()
# The checkpoint is consumed by load_state_dict, i.e. it must be a plain
# mapping of tensors. weights_only=True restricts unpickling to tensors and
# primitive containers instead of arbitrary Python objects (requires
# torch >= 1.13; it is the default from torch 2.6 onward).
model.load_state_dict(
    torch.load("mnist_cnn_.pth", map_location="cpu", weights_only=True)
)
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled.
model.eval()
# Preprocessing applied to every uploaded image before inference:
#   1. collapse to one channel (the first Conv2d of the model takes 1 channel),
#   2. force a 28x28 size (aspect ratio is not preserved),
#   3. convert to a float tensor,
#   4. standardize with a fixed mean/std.
# NOTE(review): classic MNIST digits are light strokes on a dark background;
# user uploads are often dark-on-light and may need inverting -- verify.
_NORM_MEAN = (0.1307,)  # presumably the dataset mean/std used at training
_NORM_STD = (0.3081,)   # time -- confirm against the training script
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# Output labels: logit index d corresponds to the digit string str(d).
classes = list(map(str, range(10)))
def predict(image):
    """Classify a single uploaded digit image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        Dict mapping each label in ``classes`` to its softmax probability, or
        ``None`` when no image was provided (this clears the Label output
        instead of raising inside ``transform``).
    """
    if image is None:
        return None
    # Add a batch dimension; with the module-level transform this is (1, 1, 28, 28).
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(classes, probs))
# Gradio UI: one PIL image in, the top-3 class probabilities out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a Digit"),
    outputs=gr.Label(num_top_classes=3),
    title="MNIST CNN Classifier (20 Epochs)",
    description="Upload a digit image to classify using the best CNN model trained for 20 epochs.",
)

# Start the server only when executed as a script (e.g. `python app.py`),
# so importing this module (tests, tooling, reload mode) has no side effect.
if __name__ == "__main__":
    demo.launch()