# pokemon / app.py
# Uploaded by tamvi ("Upload app.py", commit 1cc8e4d, verified); 1.79 kB, no virus detected.
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
# Location of the saved Keras model, relative to the working directory.
model_path = "transferlearning_pokemon.keras"

# Load the model once at import time so every prediction reuses it.
model = tf.keras.models.load_model(filepath=model_path)
# Define the core prediction function
def predict_pokemon(image):
    """Classify an image as one of the known Pokemon classes.

    Args:
        image: The input picture as a NumPy array, or ``None`` when no
            image was supplied.

    Returns:
        A dict mapping each class name to its predicted score as a plain
        ``float``, or ``None`` when ``image`` is ``None``.
    """
    # An empty submission would otherwise crash on ``None.astype``; return
    # an empty result instead so the UI simply shows nothing.
    if image is None:
        return None

    # Preprocess: convert the array to a PIL image, force three RGB channels
    # (a grayscale or RGBA array would otherwise reach the model with the
    # wrong channel count), and resize to 150x150.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    pil_image = pil_image.resize((150, 150))

    # Add a leading batch axis so the array matches the model input shape.
    batch = np.expand_dims(np.array(pil_image), axis=0)

    # Predict; the first (only) row of the batch holds the class scores.
    # Assuming the output is already softmax probabilities.
    probabilities = model.predict(batch)[0]

    # Assuming the model was trained with these class names, in this order.
    # NOTE(review): 'pikachuu' looks like a typo for 'pikachu' -- confirm
    # against the training labels before changing the displayed name.
    class_names = ['charmander', 'eevee', 'pikachuu']

    # Cast to built-in float so the values are plain Python numbers.
    return {name: float(probabilities[i]) for i, name in enumerate(class_names)}
# Create the Gradio interface
input_image = gr.Image()

iface = gr.Interface(
    fn=predict_pokemon,
    inputs=input_image,
    outputs=gr.Label(),
    # Sample images offered to the user; paths are relative to the app root.
    examples=[
        "pokemon_examples/charmander.png",
        "pokemon_examples/charmander1.jpg",
        "pokemon_examples/eevee.png",
        "pokemon_examples/eevee1.jpg",
        "pokemon_examples/pika.png",
        "pokemon_examples/pika1.jpg",
    ],
    # The previous text described an MLP trained on MNIST, which does not
    # match this app: it serves a Pokemon classifier.
    description=(
        "A transfer-learning image classification model that recognises "
        "Charmander, Eevee and Pikachu."
    ),
)

# Start the web server; kept at module level so running the file launches
# the app exactly as before.
iface.launch()