# Hugging Face Spaces app (page status when captured: Sleeping)
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
# Location of the saved Keras model; a relative path, so it is resolved
# against the process's current working directory.
model_path = "xception_aerial.keras"

# Load once at import time so every prediction request reuses this instance.
model = tf.keras.models.load_model(filepath=model_path)
# Define the core prediction function | |
def predict_aerial(image): | |
# Preprocess image | |
print(type(image)) | |
image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image | |
image = image.resize((150, 150)) # Resize the image to 150x150 | |
image = np.array(image) | |
image = np.expand_dims(image, axis=0) # Expand dimensions to match the model input shape | |
# Predict | |
prediction = model.predict(image) | |
# Print the shape of the prediction to debug | |
print(f"Prediction shape: {prediction.shape}") | |
# Assuming the output is already softmax probabilities | |
probabilities = prediction[0] | |
# Print the probabilities array to debug | |
print(f"Probabilities: {probabilities}") | |
# Assuming your model was trained with these class names | |
class_names = ['agriculture', 'airport', 'beach', 'city', 'forest'] # Replace 'another_pokemon' with your third class name | |
# Create a dictionary of class probabilities | |
result = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))} | |
return result | |
# Build the Gradio UI: one image input, a label output with per-class scores.
input_image = gr.Image()

# Three bundled sample images per class, listed class by class
# (aerial_examples/<class><1..3>.jpg).
_EXAMPLE_CLASSES = ["agriculture", "airport", "beach", "forest", "city"]
example_paths = [
    # One inner list per sample: one entry per input component.
    [f"aerial_examples/{label}{index}.jpg"]
    for label in _EXAMPLE_CLASSES
    for index in range(1, 4)
]

iface = gr.Interface(
    fn=predict_aerial,
    inputs=input_image,
    outputs=gr.Label(),
    examples=example_paths,
    description=(
        "Classifies an aerial image as agriculture, airport, beach, city or "
        "forest using an Xception-based image classifier."
    ),
)
iface.launch()