mnist_basic / app.py
mgspl's picture
Update app.py with new model, update desc
cff1408
raw
history blame contribute delete
No virus
3.41 kB
import functools
import os
from collections import OrderedDict

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import datasets, transforms
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into a single one.

    Implemented with ``Tensor.view``, so the input must be viewable without a
    copy (e.g. contiguous); a (N, C, H, W) input becomes (N, C*H*W).
    """

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)
class ConvNet(nn.Module):
    """Small all-convolutional MNIST classifier.

    Three strided 5x5 convolutions (1 -> 64 -> 64 -> 10 channels), each
    followed by ReLU and dropout, then global average pooling, flattening and
    a 10x10 linear layer producing one score per class.
    """

    def __init__(self):
        super().__init__()
        conv_opts = dict(kernel_size=5, stride=2, padding=1)
        # Submodules are registered in the same order as before so parameter
        # initialisation order and state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(1, 64, **conv_opts)
        self.conv2 = nn.Conv2d(64, 64, **conv_opts)
        self.conv3 = nn.Conv2d(64, 10, **conv_opts)
        self.fc = nn.Linear(10, 10)
        # One dropout module is reused after every conv block.
        self.dropout = nn.Dropout(0.5)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = Flatten()

    def forward(self, x):
        """Return the un-normalised (pre-softmax) class scores for batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.dropout(F.relu(conv(x)))
        pooled = self.avgpool(x)
        return self.fc(self.flatten(pooled))
# Checkpoint produced by the training notebook, resolved relative to the
# current working directory.
MODEL_PATH = 'this-is-mnist-model-f1c-desk-19092023.pth'

# The checkpoint stores parameters under positional keys ("0.weight",
# "3.bias", ...), i.e. the indices of the layers in the base model that was
# trained (presumably an nn.Sequential -- TODO confirm against the notebook).
# Map each ConvNet attribute to the index it was saved under.
_CHECKPOINT_LAYER_INDEX = {
    'conv1': '0',
    'conv2': '3',
    'conv3': '6',
    'fc': '11',
}

# Preprocessing is stateless, so build the pipeline once rather than per call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
])


@functools.lru_cache(maxsize=4)
def _load_model(model_path, device_type):
    """Build a ConvNet from the checkpoint at ``model_path`` on ``device_type``.

    Cached per (path, device type) so the checkpoint is read from disk once
    instead of on every prediction.
    """
    device = torch.device(device_type)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host. NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    new_state_dict = OrderedDict(
        (f'{name}.{param}', state_dict[f'{index}.{param}'])
        for name, index in _CHECKPOINT_LAYER_INDEX.items()
        for param in ('weight', 'bias')
    )
    model = ConvNet()
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()  # disable dropout for inference
    return model


def predict(img, withGradio=False):
    """Classify a single digit image and return the predicted class index.

    Args:
        img: When ``withGradio`` is True, an image array as passed by the
            Gradio ``Image`` input (converted with ``Image.fromarray``);
            otherwise an image the torchvision transforms accept (presumably
            a PIL image -- confirm for other callers).
        withGradio: Whether ``img`` must first be converted from an array.

    Returns:
        int: The index of the highest-scoring class (0-9).
    """
    if withGradio:
        img = Image.fromarray(img)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Add a batch dimension and put the input on the same device as the model;
    # mixing a CPU input with a CUDA model raises a RuntimeError.
    input_data = _PREPROCESS(img).unsqueeze(0).to(device)

    model = _load_model(MODEL_PATH, device.type)
    with torch.no_grad():
        output = model(input_data)

    probabilities = F.softmax(output[0], dim=0)
    return probabilities.argmax().item()
def wrapper_fn(input_image):
    """Gradio callback: classify the submitted image and return the digit.

    Args:
        input_image: The value produced by the ``gr.Image()`` input component,
            or ``None`` when the user submits without an image.

    Returns:
        int: The predicted class index from ``predict``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a traceback from ``Image.fromarray(None)``.
    """
    if input_image is None:
        raise gr.Error("Please upload an image of a digit first.")
    return predict(input_image, withGradio=True)
# Define Gradio interface
title = "MNIST - understanding the basics"
description = "I have created and trained a CNN for MNIST. You can find the exercise notebook [here](https://www.kaggle.com/code/mindgspl/exercise-mnist). Note : use same size image as the model 28x28, white text on black for best results. "
examples = [
    'data/0-custom-invert.jpg',
    'data/0.jpg',
    'data/2.jpg',
    'data/3.jpg',
    'data/5.jpg',
    'data/9.jpg',
    'data/0-custom.jpg',
]
output = gr.Textbox(label="Output prediction")
app = gr.Interface(
    fn=wrapper_fn,
    inputs=gr.Image(),
    outputs=output,
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Launch only when executed as a script, so importing this module (e.g. to
    # reuse `predict`) does not start a web server as a side effect.
    app.launch()