Fabsi commited on
Commit
3557ad4
1 Parent(s): 82d3f2e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+ # Load the pre-trained Pokémon model
7
+ model_path = "pokemon_classifier_model.keras"
8
+ model = tf.keras.models.load_model(model_path)
9
+
10
+ # Define the Pokémon classes
11
+ classes = ['Aerodactyl', 'Alakazam', 'Beedrill'] # Adjust these as per your model's classes
12
+
13
+ # Define the image classification function
14
+ def classify_image(image):
15
+ try:
16
+ # Ensure the image is in RGB and normalize it
17
+ if image.ndim == 2: # Check if the image is grayscale
18
+ image = np.stack((image,)*3, axis=-1) # Convert grayscale to RGB by repeating the gray channel
19
+ elif image.shape[2] == 4: # Check if the image has an alpha channel
20
+ image = image[:, :, :3] # Drop the alpha channel
21
+ image = Image.fromarray(image.astype('uint8'), 'RGB') # Convert to PIL Image to resize
22
+ image = image.resize((150, 150)) # Resize to match the model's input size
23
+
24
+ image_array = np.array(image) / 255.0 # Convert to array and normalize
25
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
26
+
27
+ # Predict using the model
28
+ prediction = model.predict(image_array)
29
+ predicted_class = classes[np.argmax(prediction)]
30
+ confidence = np.max(prediction)
31
+
32
+ return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
33
+ except Exception as e:
34
+ return str(e) # Return the error message if something goes wrong
35
+
36
+ # Create Gradio interface
37
+ input_image = gr.Image() # Using Gradio's Image component correctly
38
+ output_label = gr.Label()
39
+
40
+ interface = gr.Interface(fn=classify_image,
41
+ inputs=input_image,
42
+ outputs=output_label,
43
+ examples=["pokemon/aerodactyl.png", "pokemon/alakazam.png", "pokemon/beedrill.png"],
44
+ description="Upload an image of a Pokémon (Aerodactyl, Alakazam or Beedrill) to classify!")
45
+
46
+ interface.launch()