OnePiece / app.py
Kuwegjer's picture
Update app.py
529fdab verified
raw
history blame contribute delete
No virus
1.53 kB
import os

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
# Location of the trained Keras model ("mein_modell" is German for "my model").
# The default is the path inside the Hugging Face Space container; set the
# MODEL_PATH environment variable to load the model from elsewhere (e.g. when
# running the app locally).
MODEL_PATH = os.environ.get("MODEL_PATH", "/home/user/app/mein_modell.h5")

# Load the trained model once at import time so every request reuses it.
model = load_model(MODEL_PATH)
# Character names, in the same order as the model's output units. The order
# must match the one used during training (NOTE(review): the list is
# alphabetical, consistent with folder-derived labels — confirm).
CLASS_NAMES = ("Chopper", "Nami", "Ruffy", "Sanji", "Usopp", "Zoro")

# Spatial size (width, height) the image is resized to before inference.
INPUT_SIZE = (224, 224)


def predict_character(img):
    """Classify an image of a One Piece character.

    Args:
        img: The uploaded image as a NumPy array, as delivered by ``gr.Image``
            in numpy mode, or ``None`` when the input has been cleared.

    Returns:
        A dict mapping each name in ``CLASS_NAMES`` to the model's score for
        that character, in the format ``gr.Label`` expects. Returns an empty
        dict when ``img`` is ``None``.

    Raises:
        ValueError: If the number of model outputs does not match
            ``CLASS_NAMES`` (labels would otherwise be silently misaligned).
    """
    if img is None:
        # Gradio calls the function with None when the user clears the image.
        return {}

    # Let PIL infer the mode from the array shape, then normalise to 3-channel
    # RGB. Forcing mode="RGB" in fromarray would misread grayscale (H, W) or
    # RGBA (H, W, 4) arrays as packed RGB data.
    pil_img = Image.fromarray(img.astype("uint8")).convert("RGB")
    pil_img = pil_img.resize(INPUT_SIZE)

    batch = keras_image.img_to_array(pil_img)
    batch = np.expand_dims(batch, axis=0)  # model expects a batch; use size 1
    batch = preprocess_input(batch)  # ResNet50 preprocessing; must match training

    # verbose=0 keeps Keras from printing a progress bar for every request.
    scores = model.predict(batch, verbose=0)[0]
    if len(scores) != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {len(scores)} scores but {len(CLASS_NAMES)} "
            "class names are defined"
        )
    return {name: float(score) for name, score in zip(CLASS_NAMES, scores)}
# Define the Gradio interface.
interface = gr.Interface(
    fn=predict_character,
    # predict_character expects an RGB NumPy array; it does the resizing to the
    # model's input size itself (Gradio does not resize for the model).
    # type="numpy" and image_mode="RGB" are Gradio's defaults, stated here to
    # make that contract explicit.
    inputs=gr.Image(type="numpy", image_mode="RGB"),
    # Show the scores for all six characters the model distinguishes.
    outputs=gr.Label(num_top_classes=6),
    title="One Piece Character Classifier",
    description="Upload an image of a One Piece character and the classifier will predict which character it is.",
)

# Start the web server.
interface.launch()