pokomon / app.py
membrfab's picture
Update app.py
a0733d5 verified
raw
history blame contribute delete
No virus
1.33 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import keras as keras
# Load the saved classification model from the Keras file next to this script.
model = tf.keras.models.load_model('pokemon_classifier_model.keras')

# Labels for the model outputs. classify_image looks a name up by the index of
# the highest model output, so the order here must follow the model's output
# order (i.e. the class order of the dataset it was built from).
class_names = [
    'Gengar',
    'Pikachu',
    'Scyther',
]
def classify_image(image):
    """Classify an image as one of the Pokémon listed in ``class_names``.

    Args:
        image: The picture to classify, either a NumPy pixel array or a PIL
            image. (The ``gr.Image`` input used by this app is presumed to
            deliver a NumPy array by default -- confirm against the installed
            Gradio version.)

    Returns:
        A one-entry dict mapping the predicted class name to the highest
        value in the model output, converted to a Python float.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError('No image provided.')

    # Bring the input to a 3-channel PIL image. Letting PIL infer the mode and
    # converting afterwards handles grayscale and RGBA arrays, which would be
    # misread if the raw buffer were simply declared to be RGB.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    image = image.convert('RGB')

    # Rescale with PIL's interpolating resize. np.resize must not be used
    # here: it repeats or truncates the flattened pixel buffer to fill the
    # requested shape instead of scaling the picture.
    image = image.resize((224, 224))

    # Scale pixel values to [0, 1] and add the leading batch dimension,
    # giving an array of shape (1, 224, 224, 3).
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    batch = np.expand_dims(pixels, axis=0)

    # Predict and report the best-scoring class with its score.
    predictions = model.predict(batch)
    predicted_class = class_names[np.argmax(predictions)]
    confidence = np.max(predictions)
    return {predicted_class: float(confidence)}
# Input widget: the picture handed to classify_image.
image_input = gr.Image()
# Output widget: a label component configured to show up to three classes.
label = gr.Label(num_top_classes=3)

# Wire the classifier function to the widgets.
interface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label,
    title='Pokémon Classifier',
    description='Upload an image of Pikachu, Gengar, or Scyther, and the classifier will tell you which Pokémon it is, along with the confidence level of the prediction.',
)

# Start the web app.
interface.launch()