Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ # Load the TorchScript model (make sure to place 'lenet_cnn.pth' in your repository or accessible path)
8
+ model = torch.jit.load("lenet_cnn.pth")
9
+ model.eval()
10
+
11
+ # Set the device (here we assume CPU, adjust if needed)
12
+ device = torch.device("cpu")
13
+
14
+ def predict(data):
15
+ try:
16
+ # Extract the drawn image from the input
17
+ image = data["composite"]
18
+ if image is None or np.sum(image) == 0:
19
+ return "Error: No strokes detected. Please draw a digit."
20
+
21
+ # Convert to grayscale using the alpha channel and resize to 28x28
22
+ image = Image.fromarray(image[:, :, 3])
23
+ image = image.resize((28, 28)).convert("L")
24
+
25
+ # Normalize and convert to tensor
26
+ image = np.array(image, dtype=np.float32) / 255.0
27
+ image = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
28
+ image = image.to(device)
29
+
30
+ # Run the model inference
31
+ with torch.no_grad():
32
+ output = model(image)
33
+ probabilities = F.softmax(output, dim=1).squeeze(0).tolist()
34
+
35
+ # Create a dictionary mapping digit to probability (as a percentage)
36
+ result = {str(i): prob * 100 for i, prob in enumerate(probabilities)}
37
+ return result
38
+
39
+ except Exception as e:
40
+ return f"Error: {str(e)}"
41
+
42
+ # Create the Gradio Interface
43
+ interface = gr.Interface(
44
+ fn=predict,
45
+ inputs=gr.Sketchpad(width=560, height=560, brush=gr.Brush(line_width=25)), # Using "line_width" to adjust brush size
46
+ outputs=gr.Label(num_top_classes=3),
47
+ title="LeNet Handwritten Digit Classifier",
48
+ description="Draw a digit and press 'Submit' to classify it.",
49
+ theme="dark"
50
+ )
51
+
52
+ if __name__ == "__main__":
53
+ interface.launch(share=True)