# app.py — Hugging Face Space by amosfang
# Commit 68f59bb ("Update app.py", verified), 4.53 kB
import functools
import io
import os

import numpy as np
from PIL import Image
from skimage.transform import resize
import tensorflow as tf
from tensorflow.keras.models import load_model
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import gradio as gr
# Hugging Face Hub repository that holds the trained .h5 segmentation models
# loaded by load_model_file (fetched with snapshot_download).
REPO_ID = "amosfang/segmentation_u_net"
def pil_image_as_numpy_array(pilimg):
    """Convert an image into a NumPy array via Keras' ``img_to_array``.

    Args:
        pilimg: A PIL image, or an array-like image that
            ``tf.keras.utils.img_to_array`` accepts.

    Returns:
        The NumPy array produced by ``tf.keras.utils.img_to_array`` with its
        default arguments.
    """
    return tf.keras.utils.img_to_array(pilimg)
def _match_channels(image_array, channels):
    """Return *image_array* with its last (channel) axis adjusted to *channels*.

    Surplus channels (e.g. the alpha channel of an RGBA image) are dropped and
    a single-channel image is replicated; any other mismatch is returned
    unchanged.
    """
    current = image_array.shape[-1]
    if current > channels:
        return image_array[..., :channels]
    if current == 1 and channels > 1:
        return np.repeat(image_array, channels, axis=-1)
    return image_array


def resize_image(image, input_shape=(224, 224, 3)):
    """Convert *image* to a float array scaled by 1/255 and resize it.

    Args:
        image: A PIL image or array-like accepted by
            ``pil_image_as_numpy_array``.
        input_shape: Target shape, normally ``(height, width, channels)`` —
            the model input shape. Defaults to ``(224, 224, 3)``.

    Returns:
        A floating-point NumPy array of shape *input_shape*, as produced by
        ``skimage.transform.resize`` with anti-aliasing.
    """
    # Scale pixel values into [0, 1] (assumes 8-bit 0-255 input — TODO
    # confirm for non-Gradio callers).
    image_array = pil_image_as_numpy_array(image).astype(np.float32) / 255.
    if image_array.ndim == 3 and len(input_shape) == 3:
        # Fix the channel count first so ``resize`` only interpolates over the
        # spatial axes; resizing e.g. an RGBA image straight to 3 channels
        # would interpolate across the channel axis and mix alpha into RGB.
        image_array = _match_channels(image_array, input_shape[-1])
    return resize(image_array, input_shape, anti_aliasing=True)
@functools.lru_cache(maxsize=None)
def load_model_file(filename):
    """Download ``REPO_ID`` from the Hub and load one Keras model from it.

    Results are memoised per *filename* so that per-request callers (such as
    ``ensemble_predict``) do not re-run ``snapshot_download`` and re-read the
    model weights from disk on every prediction. Only a handful of distinct
    filenames are used, so the unbounded cache stays small.

    Args:
        filename: Name of the saved model file at the root of the snapshot.

    Returns:
        The model returned by ``load_model``; the same instance is returned
        for repeated calls with the same *filename*.
    """
    model_dir = snapshot_download(REPO_ID)
    return load_model(os.path.join(model_dir, filename))
def ensemble_predict(X_array):
    """Return the equal-weight average prediction of three U-Net variants.

    The base U-Net, VGG16 U-Net and ResNet50 U-Net models are loaded with
    ``load_model_file`` and each is run on *X_array*. Their ``predict``
    outputs are averaged with equal weight; per-model weights could be
    introduced here if one model proves more reliable than the others.

    Args:
        X_array: A single preprocessed image without a batch axis; a batch
            axis of size 1 is prepended before prediction.

    Returns:
        The element-wise mean of the three models' ``predict`` outputs,
        including the leading batch axis of size 1.
    """
    batch = np.expand_dims(X_array, 0)
    model_files = (
        'base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5',
        'vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5',
        'resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5',
    )
    # Load all models first, then run each of them on the batch.
    models = [load_model_file(name) for name in model_files]
    outputs = [model.predict(batch) for model in models]
    return sum(outputs) / len(outputs)
def get_predictions(y_prediction_encoded):
    """Turn per-class scores into 1-based class labels.

    Args:
        y_prediction_encoded: Array whose last axis holds one score per class.

    Returns:
        Integer array with the last axis removed, holding the index of the
        highest-scoring class plus one (so labels start at 1).
    """
    zero_based_labels = np.argmax(y_prediction_encoded, axis=-1)
    return zero_based_labels + 1
def _render_overlay(image, labels):
    """Draw *labels* semi-transparently over *image* and return it as a PIL image.

    Args:
        image: Background image array (the resized model input).
        labels: 2-D array of integer class labels, presumably in 1..7 to match
            the seven colours below (as produced by ``get_predictions``).

    Returns:
        A PIL image opened from an in-memory PNG of the figure.
    """
    colors = ['cyan', 'yellow', 'magenta', 'green', 'blue', 'black', 'white']
    class_names = ['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown']
    cmap = ListedColormap(colors)
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        # Half-integer limits give each label 1..N a colour band of width 1
        # centred on the integer, so the colour-bar ticks land in the middle
        # of their band instead of drifting towards the band edges.
        overlay = ax.imshow(labels, cmap=cmap, vmin=0.5, vmax=len(colors) + 0.5, alpha=0.5)
        cbar = fig.colorbar(overlay, ax=ax, ticks=np.arange(1, len(colors) + 1))
        cbar.set_ticklabels(class_names)
        buffer = io.BytesIO()
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(buffer, format='png')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    buffer.seek(0)
    return Image.open(buffer)


def predict(image):
    """Segment *image* and return a colour-coded land-cover overlay.

    Args:
        image: The input image supplied by the Gradio ``gr.Image`` component.

    Returns:
        A PIL image showing the resized input with the predicted classes
        overlaid and a labelled colour bar.
    """
    sample_image_resized = resize_image(image)
    y_pred = get_predictions(ensemble_predict(sample_image_resized)).squeeze()
    return _render_overlay(sample_image_resized, y_pred)
# Example satellite images offered in the UI (paths relative to the working
# directory).
sample_images = [['989953_sat.jpg'], ['999380_sat.jpg'], ['988205_sat.jpg']]

# Build the Gradio interface once and keep a reference to it.
iface = gr.Interface(
    predict,
    title='Land Cover Segmentation',
    description='''
The DeepGlobe Land Cover Classification Challenge offers the first public dataset containing high resolution
satellite imagery focusing on rural areas. As there are multiple land cover types and high density of annotations,
this dataset is more challenging than its counterparts launched in 2018. All satellite images contain RGB pixels,
with a pixel resolution of 50 cm. The total size of the total area of the dataset is equivalent to 10716.9 square kilometers.
This deep learning project was conducted while I collaborated with the Omdena Team on Seagrass detection challenge in 2024,
I trained on 803 images and their segmentation masks (with split of 80/20%). For this multilabel segmentation task, we trained 4 models,
the basic 4-blocks U-net CNN, VGG16 U-Net, Resnet50 U-net and Efficient Net U-net. Then, I built an ensemble model that achieved a
validation accuracy of about 75% and dice score of about 0.6.
''',
    inputs=[gr.Image()],
    outputs=[gr.Image()],
    examples=sample_images,
)

if __name__ == "__main__":
    # Launch exactly once, and only when run as a script.
    iface.launch(debug=True, share=True)