import numpy as np
from PIL import Image
from skimage.transform import resize
import tensorflow as tf
from tensorflow.keras.models import load_model
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import gradio as gr
import os
import io
REPO_ID = "amosfang/segmentation_u_net"
TRAIN_FOLDER = 'train_images'
TEST_FOLDER = 'example_images'
NUM_CLASSES = 7
def pil_image_as_numpy_array(pilimg):
    # Convert a PIL image to a numpy array with the Keras utility function
    img_array = tf.keras.utils.img_to_array(pilimg)
    return img_array

def resize_image(image, input_shape=(224, 224, 3)):
    # Convert to a NumPy array and normalize pixel values to [0, 1]
    image_array = pil_image_as_numpy_array(image)
    image = image_array.astype(np.float32) / 255.
    # Resize the image to the model's expected input shape (224x224x3)
    image_resized = resize(image, input_shape, anti_aliasing=True)
    return image_resized
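
# Illustrative usage (the file path below is a hypothetical example):
#   img = Image.open('example_images/sample.jpg')
#   x = resize_image(img)   # float32 array in [0, 1] with shape (224, 224, 3)
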
# Cache loaded models at module level so repeated predictions do not
# re-download the snapshot or reload the .h5 files on every call
_model_cache = {}

def load_model_file(filename):
    if filename not in _model_cache:
        model_dir = snapshot_download(REPO_ID)
        saved_model_filepath = os.path.join(model_dir, filename)
        _model_cache[filename] = load_model(saved_model_filepath)
    return _model_cache[filename]

def get_sample_images(image_folder, formats=(('.jpg', '.jpeg'), ('.png',))):
    # Get a sorted list of all files in the folder
    img_file_list = sorted(os.listdir(image_folder))
    # Filter image files and mask files by extension
    image_files = [os.path.join(image_folder, file) for file in img_file_list
                   if file.lower().endswith(formats[0])]
    mask_files = [os.path.join(image_folder, file) for file in img_file_list
                  if file.lower().endswith(formats[1])]
    # Without masks, return one image per example; otherwise pair each
    # image with its mask for the Gradio examples list
    if not mask_files:
        return [[file] for file in image_files]
    return [list(pair) for pair in zip(image_files, mask_files)]
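
# Illustrative return values (filenames are hypothetical):
#   without masks: [['example_images/a.jpg'], ['example_images/b.jpg']]
#   with masks:    [['train_images/a.jpg', 'train_images/a.png'], ...]
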
def ensemble_predict(X_array):
    # Call the predict methods of the base U-Net, VGG16 U-Net and
    # ResNet50 U-Net models, then average the three predictions.
    # A weighted average, giving more weight to the stronger models,
    # could also be considered to improve performance.
    X_array = np.expand_dims(X_array, axis=0)
    unet_model = load_model_file('base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5')
    vgg16_model = load_model_file('vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5')
    resnet50_model = load_model_file('resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5')
    pred_y_unet = unet_model.predict(X_array)
    pred_y_vgg16 = vgg16_model.predict(X_array)
    pred_y_resnet50 = resnet50_model.predict(X_array)
    return (pred_y_unet + pred_y_vgg16 + pred_y_resnet50) / 3
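
# A minimal sketch of the weighted average suggested above; the weights
# are illustrative assumptions, not tuned values, and this helper is not
# wired into the app:
def ensemble_predict_weighted(X_array, weights=(0.3, 0.3, 0.4)):
    X_array = np.expand_dims(X_array, axis=0)
    models = [
        load_model_file('base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5'),
        load_model_file('vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5'),
        load_model_file('resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5'),
    ]
    preds = [w * m.predict(X_array) for w, m in zip(weights, models)]
    # Normalize by the weight sum so unnormalized weights still work
    return sum(preds) / sum(weights)
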
def get_predictions(y_prediction_encoded):
    # Convert one-hot predictions to categorical indices; labels are
    # 1-based (1..NUM_CLASSES), hence the +1 after argmax
    predicted_label_indices = np.argmax(y_prediction_encoded, axis=-1) + 1
    return predicted_label_indices
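
# Illustrative example: a one-hot pixel [0, 0, 1, 0, 0, 0, 0] has argmax
# index 2 and therefore maps to label 3 under this convention.
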
def predict_on_train(image, mask):
    # Resize the image and its ground-truth mask
    sample_image_resized = resize_image(image)
    mask_resized = resize_image(mask)

    # Plot the image with the ground-truth mask overlaid
    fig, ax = plt.subplots()
    ax.imshow(sample_image_resized)
    cax = ax.imshow(mask_resized, alpha=0.5)

    # Convert the figure to a PIL Image
    image_buffer = io.BytesIO()
    plt.savefig(image_buffer, format='png')
    image_buffer.seek(0)
    mask_pil = Image.open(image_buffer)

    # Close the figure to release resources
    plt.close(fig)

    # ----------------------------------------------
    # Steps to get the prediction for the satellite image
    y_pred = ensemble_predict(sample_image_resized)
    y_pred = get_predictions(y_pred).squeeze()

    # Define custom colors for each label and build a ListedColormap
    colors = ['cyan', 'yellow', 'magenta', 'green', 'blue', 'black', 'white']
    cmap = ListedColormap(colors)

    # Plot the image with the predicted labels overlaid
    fig, ax = plt.subplots()
    ax.imshow(sample_image_resized)
    cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=NUM_CLASSES, alpha=0.5)

    # Create a colorbar and set its ticks and tick labels
    cbar = plt.colorbar(cax, ticks=np.arange(1, NUM_CLASSES + 1))
    cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest',
                         'Water', 'Barren', 'Unknown'])

    # Convert the figure to a PIL Image
    image_buffer = io.BytesIO()
    plt.savefig(image_buffer, format='png')
    image_buffer.seek(0)
    image_pil = Image.open(image_buffer)

    # Close the figure to release resources
    plt.close(fig)

    return mask_pil, image_pil
def predict_on_test(image):
    # Steps to get the prediction
    sample_image_resized = resize_image(image)
    y_pred = ensemble_predict(sample_image_resized)
    y_pred = get_predictions(y_pred).squeeze()

    # Define custom colors for each label and build a ListedColormap
    colors = ['cyan', 'yellow', 'magenta', 'green', 'blue', 'black', 'white']
    cmap = ListedColormap(colors)

    # Plot the image with the predicted labels overlaid
    fig, ax = plt.subplots()
    ax.imshow(sample_image_resized)
    cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=NUM_CLASSES, alpha=0.5)

    # Create a colorbar and set its ticks and tick labels
    cbar = plt.colorbar(cax, ticks=np.arange(1, NUM_CLASSES + 1))
    cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest',
                         'Water', 'Barren', 'Unknown'])

    # Convert the figure to a PIL Image
    image_buffer = io.BytesIO()
    plt.savefig(image_buffer, format='png')
    image_buffer.seek(0)
    image_pil = Image.open(image_buffer)

    # Close the figure to release resources
    plt.close(fig)

    return image_pil
train_images = get_sample_images(TRAIN_FOLDER)
test_images = get_sample_images(TEST_FOLDER)
description = '''
Computer vision deep learning, powered by GPU advancements, can play a
key role in urban planning and climate change research. The U-Net
architecture, which is commonly used for semantic segmentation, can be
applied to automated land cover classification and seagrass habitat
monitoring.

Our project, using the DeepGlobe Land Cover Classification Challenge 2018
dataset, trained four models (basic U-Net, VGG16 U-Net, ResNet50 U-Net,
and EfficientNet U-Net) and achieved a validation accuracy of approximately
75% and a Dice score of about 0.6 through an ensemble approach on 803 images
with segmentation masks (80/20 train/validation split).
'''
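
# The Dice score quoted above measures overlap between predicted and
# ground-truth masks: 2 * |A ∩ B| / (|A| + |B|) per class. A minimal
# illustrative sketch (not the exact evaluation code used in training):
def dice_score(y_true, y_pred, num_classes=NUM_CLASSES):
    scores = []
    for c in range(1, num_classes + 1):
        true_c = (y_true == c)
        pred_c = (y_pred == c)
        denom = true_c.sum() + pred_c.sum()
        # Treat a class absent from both masks as a perfect match
        scores.append(2 * np.logical_and(true_c, pred_c).sum() / denom if denom else 1.0)
    return np.mean(scores)
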
# Create the train dataset interface
tab1 = gr.Interface(
    fn=predict_on_train,
    inputs=[gr.Image(), gr.Image()],
    outputs=[gr.Image(label='Ground Truth'), gr.Image(label='Predicted')],
    title='Images with Ground Truth',
    description=description,
    examples=train_images
)
# Create the test dataset interface
tab2 = gr.Interface(
    fn=predict_on_test,
    inputs=gr.Image(),
    outputs=gr.Image(label='Predicted'),
    title='Images without Ground Truth',
    description=description,
    examples=test_images
)
# Create a Multi Interface with Tabs
iface = gr.TabbedInterface([tab1, tab2],
                           title='Land Cover Segmentation',
                           tab_names=['Train', 'Test'])
# Launch the interface
iface.launch(share=True)