pokemon2 / app.py
kramesab's picture
Upload 2 files
8aae1ba verified
raw
history blame contribute delete
No virus
1.83 kB
# %%
import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.resnet50 import preprocess_input
from PIL import Image
# Load the trained Keras classifier once, at import time, so every prediction
# request reuses the same in-memory model.
model = load_model("pokemon-model.keras")

# Class names indexed by the model's output position: class_labels[i] is the
# name reported when output unit i has the highest score. This order must match
# the class order used during training (presumably the alphabetical folder
# order of the training dataset — verify against the training script).
class_labels = [
    "Bulbasaur",
    "Glumanda",
    "Pikachu",
]
# Define the image processing and prediction function
def predict_image(img):
    """Classify a Pokémon image and return the predicted class name.

    Args:
        img: The image to classify. Either a ``PIL.Image.Image`` or an array
            accepted by ``PIL.Image.fromarray`` (Gradio's ``gr.Image``
            component passes a numpy array by default).

    Returns:
        The entry of ``class_labels`` at the index of the model's
        highest-scoring output.

    Raises:
        ValueError: If no image was supplied (``img`` is ``None``), e.g. when
            the form is submitted without uploading anything.
    """
    if img is None:
        raise ValueError("No image provided. Please upload an image of a Pokémon.")
    # Accept raw arrays as well as PIL images.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Force exactly three channels: RGBA, greyscale or palette images would
    # otherwise produce arrays of shape (224, 224, 4) or (224, 224) instead of
    # the (224, 224, 3) layout this function feeds to the model.
    img = img.convert('RGB')
    # Resize to the 224x224 input size used by ResNet50.
    img = img.resize((224, 224))
    # (224, 224, 3) -> batch of one: (1, 224, 224, 3)
    img_array = np.expand_dims(np.array(img), axis=0)
    # ResNet50 preprocessing; assumed to match what was used at training
    # time — confirm against the training script.
    img_array = preprocess_input(img_array)
    # verbose=0 suppresses Keras' per-call progress bar in the server logs.
    prediction = model.predict(img_array, verbose=0)
    # Index of the highest-scoring class.
    predicted_index = int(np.argmax(prediction))
    return class_labels[predicted_index]
# Create the Gradio interface with multiple example images.
# outputs='label' renders the string returned by predict_image as a Label.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(image_mode='RGB'),
    outputs='label',
    examples=[
        ['00000015.jpg'],
        ['20.png'],
        ['glumanda.jpg'],
        ['j67j7.png'],
        ['pikachu.jpg'],
    ],
    title="Pokémon Classification",
    description="Upload an image of a Pokémon to classify it using the pre-trained model.",
)

# Start the server only when this file is executed directly (script or
# notebook cell), not when it is imported, e.g. to reuse predict_image.
if __name__ == "__main__":
    # inline=True embeds the app in the notebook output when run from Jupyter.
    iface.launch(inline=True)
# %%
# Print the model summary to verify the input shape.
# Model.summary() writes the table itself and returns None, so wrapping it in
# print() would emit a stray "None" line after the summary.
model.summary()