File size: 1,990 Bytes
9c488ae
9c1a58a
fb48e9e
9c1a58a
 
 
 
 
868bd7c
9c1a58a
dc97ffe
0af85f8
9c1a58a
9c488ae
fb48e9e
dc97ffe
 
 
 
 
 
 
 
 
9c1a58a
 
 
fb48e9e
 
 
9c1a58a
6dbf2f2
fb48e9e
f1948ae
9f5da3c
9c1a58a
dc97ffe
a80e9f3
f1948ae
9c1a58a
 
 
1d489e4
e3efc57
9f5da3c
dc97ffe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
import tensorflow as tf
from PIL import Image
import numpy as np

# Restore the trained Keras Pokémon classifier from disk once, at import time
model_path = "pokemon_classifier_model.keras"
model = tf.keras.models.load_model(model_path)

# Human-readable label names. The list is indexed with the argmax of the
# model's output, so the order must match the model's class order
# (edit this list if the model was trained on different classes).
classes = [
    "Doduo",
    "Geodude",
    "Zubat",
]

# Define the image classification function
def classify_image(image):
    """Classify an uploaded image as one of the Pokémon listed in ``classes``.

    Args:
        image: Pixel data as an array-like of shape (H, W), (H, W, 1),
            (H, W, 3) or (H, W, 4), or ``None`` when nothing was uploaded.
            Values are cast to ``uint8``, so 0-255 intensities are expected.

    Returns:
        A string of the form
        ``"Predicted Pokémon: <name>, Confidence: <pct>%"``.
        If no image was supplied, a short prompt asking for one. If anything
        fails during preprocessing or prediction, the exception's message is
        returned instead of being raised, so the UI always gets text back.
    """
    # Submitting the form without an image hands the callback None; answer
    # with a readable prompt instead of leaking an AttributeError message.
    if image is None:
        return "No image provided. Please upload an image of a Pokémon."

    try:
        pixels = np.asarray(image)

        # Bring every supported layout to exactly three colour channels.
        if pixels.ndim == 2:  # plain grayscale: add an explicit channel axis
            pixels = pixels[:, :, np.newaxis]
        channel_count = pixels.shape[2]
        if channel_count == 1:  # grayscale: repeat the gray channel as R, G, B
            pixels = np.repeat(pixels, 3, axis=-1)
        elif channel_count == 4:  # RGBA: drop the alpha channel
            pixels = pixels[:, :, :3]

        # Go through PIL only for resizing. The mode is inferred from the
        # 3-channel uint8 array, so the deprecated explicit ``mode``
        # argument is not passed.
        resized = Image.fromarray(pixels.astype('uint8')).resize((150, 150))

        # Scale to [0, 1] and add the leading batch axis: (1, 150, 150, 3).
        batch = np.expand_dims(np.array(resized) / 255.0, axis=0)

        # Predict using the model; the highest score selects the class name.
        prediction = model.predict(batch)
        predicted_class = classes[int(np.argmax(prediction))]
        confidence = np.max(prediction)

        return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
    except Exception as e:
        # UI boundary: report the failure as text rather than crash the app.
        return str(e)

# Build the Gradio UI: one image input wired to classify_image, with the
# returned text shown in a label component.
input_image = gr.Image()
output_label = gr.Label()

interface = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_label,
    # Sample images offered as one-click examples in the UI
    examples=[
        "pokemon/doduo.png",
        "pokemon/geodude.png",
        "pokemon/zubat.png",
    ],
    description="Upload an image of a Pokémon (Doduo, Geodude or Zubat) to classify!",
)

# Start the web server for the app
interface.launch()