import os
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from PIL import Image
import gradio as gr


class Flatten(nn.Module):
    # Flattens every dimension except the batch dimension
    def forward(self, input):
        return input.view(input.size(0), -1)


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 10, kernel_size=5, stride=2, padding=1)
        self.fc = nn.Linear(10, 10)
        self.dropout = nn.Dropout(0.5)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = Flatten()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        x = F.relu(self.conv3(x))
        x = self.dropout(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


def predict(img, withGradio=False):
    # Gradio passes the image in as a numpy array; convert it to a PIL image first
    if withGradio:
        img = Image.fromarray(img)

    # Preprocess: resize to 28x28, convert to single-channel grayscale, scale to [0, 1]
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    img_tensor = transform(img)
    input_data = img_tensor.unsqueeze(0)  # add a batch dimension: (1, 1, 28, 28)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Get the state_dict of the trained conv_net base model
    state_dict = torch.load('this-is-mnist-model-f1c-desk-19092023.pth', map_location=device)

    # Define a new state_dict for ConvNet
    new_state_dict = OrderedDict()

    # Manually map the layer indices of the conv_net base model to the named layers of ConvNet
    new_state_dict['conv1.weight'] = state_dict['0.weight']
    new_state_dict['conv1.bias'] = state_dict['0.bias']
    new_state_dict['conv2.weight'] = state_dict['3.weight']
    new_state_dict['conv2.bias'] = state_dict['3.bias']
    new_state_dict['conv3.weight'] = state_dict['6.weight']
    new_state_dict['conv3.bias'] = state_dict['6.bias']
    new_state_dict['fc.weight'] = state_dict['11.weight']
    new_state_dict['fc.bias'] = state_dict['11.bias']

    # Load the new_state_dict into ConvNet
    model = ConvNet()
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()  # Set the model to evaluation mode

    # Pass the input data to the model (moved to the same device as the model)
    input_data = input_data.to(device)
    with torch.no_grad():
        output = model(input_data)

    # Postprocess the output
    probabilities = F.softmax(output[0], dim=0)
    n_predicted_class = probabilities.argmax().item()
    return n_predicted_class


def wrapper_fn(input_image):
    return predict(input_image, withGradio=True)


# Define the Gradio interface
title = "MNIST - understanding the basics"
description = (
    "I have created and trained a CNN for MNIST. You can find the exercise notebook "
    "[here](https://www.kaggle.com/code/mindgspl/exercise-mnist). "
    "Note: use the same image size as the model (28x28), with white digits on a black background, for best results."
)
examples = [
    'data/0-custom-invert.jpg',
    'data/0.jpg',
    'data/2.jpg',
    'data/3.jpg',
    'data/5.jpg',
    'data/9.jpg',
    'data/0-custom.jpg',
]
output = gr.Textbox(label="Output prediction")
app = gr.Interface(
    fn=wrapper_fn,
    inputs=gr.Image(),
    outputs=output,
    title=title,
    description=description,
    examples=examples,
)

# Launch the app
app.launch()