"""Gradio demo: MNIST digit classification with a small PyTorch CNN."""
import os
from collections import OrderedDict

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import datasets, transforms
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    Functionally equivalent to ``nn.Flatten()``; kept as a named class for
    backward compatibility with existing references to it.
    """

    def forward(self, input):
        # torch.flatten works on non-contiguous tensors too, unlike the
        # previous ``input.view(input.size(0), -1)``, which raises for them.
        return torch.flatten(input, start_dim=1)
class ConvNet(nn.Module):
    """Small CNN classifying 28x28 grayscale MNIST digits into 10 classes.

    Three strided 5x5 convolutions (no pooling layers) downsample the input,
    adaptive average pooling reduces each of the 10 feature maps to a single
    value, and a final linear layer produces the class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 10, kernel_size=5, stride=2, padding=1)
        self.fc = nn.Linear(10, 10)
        self.dropout = nn.Dropout(0.5)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # nn.Flatten is equivalent to the hand-rolled Flatten module and has
        # no parameters, so existing checkpoints load unchanged.
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, 10)."""
        x = self.dropout(F.relu(self.conv1(x)))
        x = self.dropout(F.relu(self.conv2(x)))
        x = self.dropout(F.relu(self.conv3(x)))
        x = self.avgpool(x)   # -> (batch, 10, 1, 1)
        x = self.flatten(x)   # -> (batch, 10)
        return self.fc(x)
# Cache of loaded models keyed by torch.device, so the checkpoint is read
# from disk once per device instead of on every prediction.
_MODEL_CACHE = {}


def _load_model(device):
    """Build a ConvNet and load the remapped checkpoint onto *device*.

    The checkpoint was saved from an ``nn.Sequential``, so its keys are layer
    indices ('0.weight', '3.bias', ...); map them onto ConvNet's attribute
    names before loading.
    """
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load('this-is-mnist-model-f1c-desk-19092023.pth',
                            map_location=device)
    new_state_dict = OrderedDict()
    new_state_dict['conv1.weight'] = state_dict['0.weight']
    new_state_dict['conv1.bias'] = state_dict['0.bias']
    new_state_dict['conv2.weight'] = state_dict['3.weight']
    new_state_dict['conv2.bias'] = state_dict['3.bias']
    new_state_dict['conv3.weight'] = state_dict['6.weight']
    new_state_dict['conv3.bias'] = state_dict['6.bias']
    new_state_dict['fc.weight'] = state_dict['11.weight']
    new_state_dict['fc.bias'] = state_dict['11.bias']
    model = ConvNet()
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()  # disable dropout for inference
    return model


def _get_model(device):
    """Return the cached model for *device*, loading it on first use."""
    if device not in _MODEL_CACHE:
        _MODEL_CACHE[device] = _load_model(device)
    return _MODEL_CACHE[device]


def predict(img, withGradio=False):
    """Classify a digit image and return the predicted class index (0-9).

    Parameters
    ----------
    img : PIL.Image.Image or numpy.ndarray
        Input image; a raw numpy array when called from Gradio.
    withGradio : bool
        When True, convert the numpy array from Gradio into a PIL image first.
    """
    if withGradio:
        img = Image.fromarray(img)
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Bug fix: the input must live on the same device as the model —
    # previously the tensor stayed on CPU while the model could be on CUDA.
    input_data = transform(img).unsqueeze(0).to(device)
    model = _get_model(device)
    with torch.no_grad():
        output = model(input_data)
    probabilities = F.softmax(output[0], dim=0)
    return probabilities.argmax().item()
def wrapper_fn(input_image):
    """Gradio entry point: delegate to predict() with the Gradio flag set."""
    return predict(input_image, withGradio=True)
# --- Gradio interface ---
title = "MNIST - understanding the basics"
description = "I have created and trained a CNN for MNIST. You can find the exercise notebook [here](https://www.kaggle.com/code/mindgspl/exercise-mnist). Note : use same size image as the model 28x28, white text on black for best results. "
examples = [
    'data/0-custom-invert.jpg',
    'data/0.jpg',
    'data/2.jpg',
    'data/3.jpg',
    'data/5.jpg',
    'data/9.jpg',
    'data/0-custom.jpg',
]
output = gr.Textbox(label="Output prediction")
app = gr.Interface(
    fn=wrapper_fn,
    inputs=gr.Image(),
    outputs=output,
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Guard so importing this module (e.g. from tests) does not start the
    # web server; running the script directly still launches as before.
    app.launch()