membrfab's picture
Update app.py
f0d4b98 verified
raw
history blame contribute delete
No virus
1.67 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
# Load the single transfer-learning classifier (saved as model_vgg16.keras)
# that the Gradio interface below runs inference with.
transfer_learning_model = tf.keras.models.load_model(
    'model_vgg16.keras'
)

# Human-readable label for each output unit of the model.
# NOTE(review): presumably this matches the label order used in training
# (e.g. alphabetical directory order) — confirm.
class_names = [
    'butterfly',
    'cat',
    'elephant',
    'horse',
    'squirrel',
]
def classify_image(image, model, labels=None, top_k=3, size=(300, 300)):
    """Classify an image and return the top-k labels with their confidences.

    Args:
        image: The value from the Gradio ``gr.Image`` component, either a
            NumPy array (Gradio's default ``type="numpy"``) or a PIL image.
            ``None`` (nothing uploaded) yields an empty result.
        model: A Keras model whose ``predict`` returns one score per class
            for each image in the batch.
        labels: Class names in the order of the model's output units.
            Defaults to the module-level ``class_names``.
        top_k: Number of highest-scoring classes to return.
        size: Target ``(width, height)`` the image is resampled to before
            inference. It must match the model's expected input size.

    Returns:
        A dict mapping each of the ``top_k`` best class names to its score
        as a Python float, in descending score order.
    """
    if image is None:
        # Gradio calls the function with None when no image was submitted.
        return {}
    if labels is None:
        labels = class_names

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    # Normalise grayscale / RGBA / palette input to 3-channel RGB. Forcing
    # mode 'RGB' in fromarray would misread or reject non-RGB arrays.
    image = image.convert('RGB')

    # Actually resample the pixels. np.resize only repeats or truncates the
    # flattened buffer, which scrambles the picture instead of scaling it.
    image = image.resize(size, Image.BILINEAR)

    # NOTE(review): /255 scaling assumes the model was trained on [0, 1]
    # inputs rather than VGG16's preprocess_input. Confirm against training.
    batch = np.asarray(image, dtype=np.float32)[np.newaxis, ...] / 255.0

    # verbose=0 keeps Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]
    best = np.argsort(scores)[::-1][:top_k]
    return {labels[i]: float(scores[i]) for i in best}
def _classify_with_transfer_model(image):
    """Gradio callback: classify *image* with the transfer-learning model."""
    return classify_image(image, transfer_learning_model)


# Input: an image upload widget. Output: a label showing the 3 best classes.
image_input = gr.Image()
label = gr.Label(num_top_classes=3)

transfer_learning_interface = gr.Interface(
    fn=_classify_with_transfer_model,
    inputs=image_input,
    outputs=label,
    title='Animal Classifier',
    description='Upload an image of a butterfly, a cat, an elephant, a horse or a squirrel, and the classifier will tell you which animal it is, along with the confidence level of the prediction.',
)

# Start the web server (runs when the Space executes this script).
transfer_learning_interface.launch()