alibayram committed on
Commit
3646319
·
1 Parent(s): ade9637

Refactor app.py: enhance prediction function, update title and description, and improve image handling

Browse files
Files changed (1) hide show
  1. app.py +57 -32
app.py CHANGED
@@ -1,47 +1,72 @@
1
- # import dependencies
2
  import gradio as gr
3
  import tensorflow as tf
4
  import cv2
5
 
6
- # app title
7
- title = "Welcome on your first sketch recognition app!"
8
 
9
- # app description
10
- head = (
11
- "<center>"
12
- "<img src='./mnist-classes.png' width=400>"
13
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
- "</center>"
15
- )
16
 
17
- # GitHub repository link
18
- ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
 
 
19
 
20
- # image size: 28x28
21
- img_size = 28
 
22
 
23
- # classes name (from 0 to 9)
24
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 
25
 
26
- # load model (trained on MNIST dataset)
27
- model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
 
 
 
28
 
29
- # prediction function for sketch recognition
30
- def predict(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # image shape: 28x28x1
33
- img = cv2.resize(img, (img_size, img_size))
34
- img = img.reshape(1, img_size, img_size, 1)
 
35
 
36
- # model predictions
37
- preds = model.predict(img)[0]
 
38
 
39
- # return the probability for each class
40
- return {label: float(pred) for label, pred in zip(labels, preds)}
41
 
42
- # top 3 of classes
43
- label = gr.outputs.Label(num_top_classes=3)
44
 
45
- # open Gradio interface for sketch recognition
46
- interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
47
- interface.launch(share=True)
 
1
+ import numpy as np
2
  import gradio as gr
3
  import tensorflow as tf
4
  import cv2
5
 
6
# Class names (0 to 9); predict() indexes this list with the model's output positions.
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]

# Trained MNIST model, loaded once at import time and reused by every prediction.
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
 
 
 
 
 
11
 
12
def predict(data):
    """Classify a sketched digit and return its most likely classes.

    Args:
        data: Value emitted by the sketch component. Expected to be a dict
            whose "composite" entry holds the drawn image (anything
            ``np.asarray`` can convert). ``None``, a missing "composite"
            entry, or an empty image means there is nothing to classify.

    Returns:
        A dict mapping up to three class names from ``labels`` to their
        predicted probability as ``float``, most likely first. An empty
        dict is returned when there is no image to classify.
    """
    # The component may hand over no value (or no composite) when the canvas
    # is empty; return no predictions instead of crashing on a missing image.
    if data is None:
        return {}
    composite = data.get("composite")
    if composite is None:
        return {}

    img = np.asarray(composite)
    if img.size == 0:
        return {}

    # Only a 3-D array has a channel axis. Checking ndim first avoids
    # mistaking a 2-D grayscale image that is 3 or 4 pixels wide for RGB/RGBA.
    if img.ndim == 3:
        if img.shape[-1] == 4:  # RGBA
            # NOTE(review): this drops the alpha channel without blending;
            # confirm the editor's composite has an opaque background.
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        if img.shape[-1] == 3:  # RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Resize image to 28x28
    img = cv2.resize(img, (28, 28))

    # Normalize pixel values to [0, 1]
    img = img / 255.0

    # Reshape to match model input (1, 28, 28, 1)
    img = img.reshape(1, 28, 28, 1)

    # Model predictions for the single image in the batch
    preds = model.predict(img)[0]

    # Indices of the (up to) three highest scores, best first
    top_classes = np.argsort(preds)[-3:][::-1]

    return {labels[i]: float(preds[i]) for i in top_classes}
44
+
45
# Title and description shown at the top of the page
title = "Welcome to your first sketch recognition app!"
head = (
    "<center>"
    "<img src='./mnist-classes.png' width=400>"
    "<p>The model is trained to classify numbers (from 0 to 9). "
    "To test it, draw your number in the space provided (use the editing tools in the image editor).</p>"
    "</center>"
)
# Link to the reference implementation
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

with gr.Blocks(title=title) as demo:
    # Blocks(title=...) only sets the browser tab title, so render the title
    # as a heading as well; otherwise it never appears on the page.
    gr.Markdown(f"# {title}")
    gr.Markdown(head)
    gr.Markdown(ref)

    with gr.Row():
        # Using ImageEditor with type='numpy'
        im = gr.ImageEditor(type="numpy", label="Draw your digit here (use brush and eraser)")

        # Output label (top 3 predictions)
        label = gr.Label(num_top_classes=3, label="Predictions")

    # Trigger prediction whenever the image changes
    im.change(predict, inputs=im, outputs=label, show_progress="hidden")

# Only start the server when run as a script, not when imported
if __name__ == "__main__":
    demo.launch(share=True)