File size: 3,434 Bytes
c863d81
 
 
 
 
3e0efdf
c863d81
a24629d
c863d81
 
 
 
d0f3245
 
c863d81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Importing necessary libraries
import numpy as np                          # For handling arrays and numerical operations
import gradio as gr                         # For creating a simple web interface for interacting with the model
from PIL import Image                       # For image processing (like resizing)
from tensorflow import keras                 # For building and working with machine learning models

# Network architecture: a plain multilayer perceptron applied to a 28x28 input.
# The layers are listed first and then stacked, in this order, by Sequential.
mlp_layers = [
  keras.layers.Flatten(input_shape=(28, 28)),     # unroll the 28x28 grid into a 784-element vector
  keras.layers.Dense(512, activation='relu'),     # hidden layer #1: 512 ReLU units
  keras.layers.Dense(512, activation='relu'),     # hidden layer #2: 512 ReLU units
  keras.layers.Dense(10, activation='softmax'),   # one output per digit 0-9; softmax turns scores into probabilities
]
model = keras.models.Sequential(mlp_layers)

# Compiling the model with Adam optimizer and sparse categorical cross-entropy loss function
model.compile(optimizer=keras.optimizers.Adam(0.001),  # Adam optimizer with learning rate of 0.001
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # Using cross-entropy loss for multi-class classification
              metrics=[keras.metrics.SparseCategoricalAccuracy()])  # Measuring accuracy during training

# Restore the pre-trained parameters into the freshly built model.
# The path is relative, so it is resolved against the current working
# directory (not necessarily this script's location).
weights_file = './weights/mnist.weights.h5'
model.load_weights(weights_file)

# Defining the function that will be used to classify the input image
def classify(input):
    """Classify a hand-drawn digit coming from the Gradio paint canvas.

    Args:
        input: The value produced by the ``gr.Paint`` component. Only
            ``input['layers'][0]`` (the first drawn layer, as an array) is
            used. May be ``None`` or contain no layers when there is nothing
            to classify. (The name shadows the builtin ``input``; it is kept
            for interface compatibility.)

    Returns:
        A dict mapping each digit label ``'0'``..``'9'`` to the model's
        softmax probability for that digit, as expected by ``gr.Label``.
        Returns ``None`` (clears the label) when there is no drawing.
    """
    # Guard against an empty submission instead of crashing with a
    # TypeError (input is None) or IndexError (no layers).
    if input is None or len(input['layers']) == 0:
        return None

    # Downscale the first drawn layer to the 28x28 resolution the model's
    # Flatten layer expects.
    drawing = Image.fromarray(input['layers'][0]).resize(
        (28, 28), resample=Image.Resampling.BILINEAR)

    # Add a leading batch axis: (28, 28) -> (1, 28, 28).
    # NOTE(review): assumes image_mode="L" yields a 2-D grayscale layer. Confirm
    # this if the canvas mode ever changes.
    batch = np.expand_dims(np.array(drawing, dtype=int), axis=0)

    # verbose=0 stops Keras from printing a progress bar on every request.
    probabilities = model.predict(batch, verbose=0)[0]

    # One entry per digit class, in output-neuron order 0..9.
    return {str(digit): float(p) for digit, p in enumerate(probabilities)}

# Gradio UI: the user sketches a digit on a grayscale canvas, classify() scores
# it, and a Label component shows the per-digit probabilities.
canvas_brush = gr.components.image_editor.Brush(default_color="rgb(156, 104, 200)")  # purple drawing brush
input_sketchpad = gr.Paint(image_mode="L", brush=canvas_brush)  # canvas the user draws the digit on
output_label = gr.Label()  # displays the dict returned by classify()

# Assemble the app and start serving it.
demo = gr.Interface(
    fn=classify,
    inputs=input_sketchpad,
    outputs=output_label,
    flagging_mode='never',   # hide the "Flag" button
    theme=gr.themes.Soft(),  # soft visual theme
)
demo.launch()