from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
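
# Note: recent PyTorch versions ship an equivalent nn.Flatten() module;
# this custom Flatten is kept to mirror the original script.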

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 10, kernel_size=5, stride=2, padding=1)
        self.fc = nn.Linear(10,10)
        self.dropout = nn.Dropout(0.5)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = Flatten()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        x = F.relu(self.conv3(x))
        x = self.dropout(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
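
# Shape sanity check (a quick sketch, not run at import time): with a 28x28
# grayscale input, the three stride-2 convolutions shrink the feature map to
# 2x2 before pooling, so the network emits one logit per digit class:
#
#   _probe = ConvNet()(torch.zeros(1, 1, 28, 28))
#   assert _probe.shape == (1, 10)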

def predict(img, withGradio=False):
    if withGradio:
        img = Image.fromarray(img)
    
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])

    img_tensor = transform(img)
    input_data = img_tensor.unsqueeze(0)  # add a batch dimension -> (1, 1, 28, 28)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_data = input_data.to(device)  # keep the input on the same device as the model

    # Load the state_dict saved from the base conv_net, remapping storage onto
    # the current device so a GPU-saved checkpoint also loads on CPU
    state_dict = torch.load('this-is-mnist-model-f1c-desk-19092023.pth', map_location=device)

    # Define a new state_dict for ConvNet
    new_state_dict = OrderedDict()

    # Manually map the parameter names: the numeric keys ('0.weight', '3.weight', ...)
    # indicate the checkpoint came from an nn.Sequential, so each index is matched
    # to the corresponding named module in ConvNet
    new_state_dict['conv1.weight'] = state_dict['0.weight']
    new_state_dict['conv1.bias'] = state_dict['0.bias']
    new_state_dict['conv2.weight'] = state_dict['3.weight']
    new_state_dict['conv2.bias'] = state_dict['3.bias']
    new_state_dict['conv3.weight'] = state_dict['6.weight']
    new_state_dict['conv3.bias'] = state_dict['6.bias']
    new_state_dict['fc.weight'] = state_dict['11.weight']
    new_state_dict['fc.bias'] = state_dict['11.bias']

    # Load the new_state_dict into ConvNet (note: rebuilding the model on every
    # prediction is simple but slow; a deployed app would load it once at startup)
    model = ConvNet()
    model.load_state_dict(new_state_dict)
    
    model.to(device)
    model.eval()  # Set the model to evaluation mode
    
    # Pass the input data to the model
    with torch.no_grad():
        output = model(input_data)

    # Postprocess the output
    probabilities = F.softmax(output[0], dim=0)
    n_predicted_class = probabilities.argmax().item()
    return n_predicted_class
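
# Local usage sketch (no Gradio), using one of the bundled sample images;
# the printed digit depends on the trained weights:
#
#   img = Image.open('data/3.jpg')
#   print(predict(img))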


def wrapper_fn(input_image):
    return predict(input_image, withGradio=True)

# Define Gradio interface
title = "MNIST - understanding the basics"
description = "I have created and trained a CNN for MNIST. You can find the exercise notebook [here](https://www.kaggle.com/code/mindgspl/exercise-mnist). Note : use same size image as the model 28x28, white text on black for best results. "
examples = ['data/0-custom-invert.jpg', 'data/0.jpg', 'data/2.jpg', 'data/3.jpg',  'data/5.jpg',  'data/9.jpg', 'data/0-custom.jpg',]
output = gr.Textbox(label="Output prediction")
app = gr.Interface(fn=wrapper_fn, inputs=gr.Image(), outputs=output, title=title,description=description,examples=examples)

# Launch the app
app.launch()
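
# Note: app.launch(share=True) would additionally expose a temporary public
# URL, useful when running from a notebook or a remote machine.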