| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image |
|
|
class BestMNISTCNN(nn.Module):
    """Three-stage convolutional classifier that emits 10 logits per image.

    Each stage is two 3x3 convolutions, each followed by batch normalisation
    and ReLU. The first two stages finish with a 2x2 max-pool, the third with
    adaptive average pooling to 1x1, so the linear head always receives 128
    features regardless of the input's spatial size. Every stage ends with
    dropout.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 32 channels, spatial size halved by the max-pool.
        self.convblock1 = nn.Sequential(
            *self._conv_pair(1, 32),
            nn.MaxPool2d(2),
            nn.Dropout(0.1),
        )

        # Stage 2: 32 -> 64 channels, spatial size halved again.
        self.convblock2 = nn.Sequential(
            *self._conv_pair(32, 64),
            nn.MaxPool2d(2),
            nn.Dropout(0.1),
        )

        # Stage 3: 64 -> 128 channels, collapsed to one value per channel.
        self.convblock3 = nn.Sequential(
            *self._conv_pair(64, 128),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Dropout(0.2),
        )

        # Classification head: 128 pooled features -> 10 class logits.
        self.fc = nn.Linear(128, 10)

    @staticmethod
    def _conv_pair(in_channels, out_channels):
        """Return two conv(3x3, padding=1) -> batch-norm -> ReLU groups.

        The first convolution maps ``in_channels`` to ``out_channels``; the
        second keeps ``out_channels``. Layers come back as a flat list so the
        caller can splat them into ``nn.Sequential`` and keep flat indices.
        """
        layers = []
        for source_channels in (in_channels, out_channels):
            layers.append(nn.Conv2d(source_channels, out_channels, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        return layers

    def forward(self, x):
        """Run the three stages and the linear head.

        Args:
            x: Batch of single-channel images; the first dimension is treated
                as the batch dimension.

        Returns:
            Tensor with one row of 10 unnormalised class scores per input.
        """
        for stage in (self.convblock1, self.convblock2, self.convblock3):
            x = stage(x)
        # Drop the trailing 1x1 spatial dimensions before the linear layer.
        flat = x.view(x.size(0), -1)
        return self.fc(flat)
|
|
# Build the network and restore its trained parameters from the checkpoint.
model = BestMNISTCNN()
model.load_state_dict(
    torch.load(
        "mnist_cnn_.pth",
        map_location="cpu",
        # Restrict unpickling to tensors and plain containers so a tampered
        # checkpoint file cannot execute arbitrary code while being loaded.
        weights_only=True,
    )
)
# Switch to evaluation mode: dropout is disabled and batch norm uses its
# stored running statistics instead of per-batch statistics.
model.eval()
|
|
# Preprocessing applied to every incoming image, in this order.
_preprocessing_steps = [
    # Reduce the image to a single channel.
    transforms.Grayscale(num_output_channels=1),
    # Force a fixed 28x28 size (aspect ratio is not preserved).
    transforms.Resize(size=(28, 28)),
    # Convert the PIL image to a float tensor.
    transforms.ToTensor(),
    # Standardise with a fixed mean/std; presumably the MNIST training-set
    # statistics -- confirm against the training script.
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
# Class labels "0" through "9"; position i names the model's i-th output.
classes = list(map(str, range(10)))
|
|
def predict(image):
    """Classify a digit image and return a probability for every class.

    Args:
        image: The PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        A dict mapping each label in ``classes`` to the softmax probability
        the model assigns to it, as a Python float.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when nothing was uploaded; report that in the UI
        # rather than failing obscurely inside the preprocessing pipeline.
        raise gr.Error("Please upload an image of a digit first.")

    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]

    # Pair labels with probabilities directly instead of indexing a
    # hard-coded range; tolist() yields plain Python floats.
    return dict(zip(classes, probs.tolist()))
|
|
# Gradio UI: one image input wired to `predict`, showing the three most
# probable labels.
_digit_input = gr.Image(type="pil", label="Upload a Digit")
_top_labels_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=_digit_input,
    outputs=_top_labels_output,
    title="MNIST CNN Classifier (20 Epochs)",
    description="Upload a digit image to classify using the best CNN model trained for 20 epochs.",
)

# Start the web server for the interface.
demo.launch()