# Page header captured with this file (not code):
#   Chancee12 — "Update app.py" (commit 964f4ce)
import gradio as gr
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import random
from keras.utils import get_custom_objects
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import segmentation_models as sm
from keras import backend as K
from keras.models import load_model
def _hex_to_rgb(hex_color):
    """Convert a ``'#RRGGBB'`` (or bare ``'RRGGBB'``) string to an RGB array.

    Args:
        hex_color: six hex digits, optionally prefixed with ``'#'``.

    Returns:
        A length-3 numpy integer array ``[R, G, B]``; each channel is 0-255.
    """
    digits = hex_color.lstrip('#')
    # Parse the RR, GG and BB pairs as base-16 integers.
    return np.array(tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4)))


# One display colour per segmentation class, stored as RGB arrays so they can
# be assigned directly into pixels of an (H, W, 3) image.
class_building = _hex_to_rgb('#3C1098')
class_land = _hex_to_rgb('#8429F6')
class_road = _hex_to_rgb('#6EC1E4')
class_vegetation = _hex_to_rgb('#FEDD3A')
class_water = _hex_to_rgb('#E2A929')
class_unlabeled = _hex_to_rgb('#9B9B9B')
def jaccard_coef(y_true, y_pred):
    """Smoothed Jaccard (intersection-over-union) coefficient.

    Both tensors are flattened and compared element-wise, so the score is
    pooled over every element rather than averaged per class or per sample.

    Returns:
        ``(sum(t * p) + 1) / (sum(t) + sum(p) - sum(t * p) + 1)``, where ``t``
        and ``p`` are the flattened ``y_true`` and ``y_pred``.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    union = K.sum(truth) + K.sum(guess) - overlap
    # Adding 1.0 to numerator and denominator keeps the ratio defined (== 1)
    # when both inputs are all zeros.
    return (overlap + 1.0) / (union + 1.0)
# The model has six output classes, so the Dice loss gets six equal weights.
# NOTE(review): 0.1666 * 6 == 0.9996, not exactly 1; harmless for inference.
weights = [0.1666] * 6
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
# Keep the explicit ``1 *``: the custom_objects key below presumably mirrors the
# name segmentation_models derives for ``dice_loss + 1 * focal_loss``.
total_loss = dice_loss + (1 * focal_loss)
# The saved model refers to its loss and metric by name. Presumably both
# must be resolvable via custom_objects for load_model to deserialise it.
satellite_model = load_model(
    'satellite_segmentation_full_v2.h5',
    custom_objects={
        'dice_loss_plus_1focal_loss': total_loss,
        'jaccard_coef': jaccard_coef,
    },
)
def label_to_rgb(label_segment):
    """Paint a map of integer class labels as an RGB image.

    Args:
        label_segment: array of class indices, presumably 2-D (H, W).

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``. Labels 0-5 map to water,
        land, road, building, vegetation and unlabeled colours respectively.
        Any other label is left black.
    """
    canvas = np.zeros((label_segment.shape[0], label_segment.shape[1], 3), dtype=np.uint8)
    # Position i in this tuple is the colour used for class label i.
    palette = (
        class_water,
        class_land,
        class_road,
        class_building,
        class_vegetation,
        class_unlabeled,
    )
    for class_index, colour in enumerate(palette):
        canvas[label_segment == class_index] = colour
    return canvas
def process_input_image(image_source):
    """Run the segmentation model on one image and colour the predicted mask.

    Args:
        image_source: the value delivered by the ``gr.Image`` input, or
            ``None`` when the button is clicked before an image is chosen.
            Presumably an (H, W, 3) array already resized by the input
            component; TODO confirm it matches the model's expected input.

    Returns:
        A ``(label_text, image)`` pair for the Label and Image outputs. On
        success the text is ``"Predicted Masked Image"`` and the image is the
        RGB mask produced by ``label_to_rgb``. With no input, the text asks for
        an image and the image output is cleared (``None``).
    """
    # Without this guard, np.expand_dims(None, 0) builds an object array and
    # predict() fails with an opaque error shown to the user.
    if image_source is None:
        return "Please select a source image first", None
    # The model predicts on batches, so add a leading batch axis of size 1.
    batch = np.expand_dims(image_source, 0)
    prediction = satellite_model.predict(batch)
    # Per-pixel class index via argmax over axis 3 (the per-class scores),
    # then take the single item of the batch.
    predicted_labels = np.argmax(prediction, axis=3)[0]
    colored_predicted_image = label_to_rgb(predicted_labels)
    return "Predicted Masked Image", colored_predicted_image
with gr.Blocks() as my_app:
    gr.Markdown("Image Processing Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                # Left column: image picker and the button that runs inference.
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image", shape=(256, 256))
                    source_image_loader = gr.Button("Load above Image")
                # Right column: status text and the coloured segmentation mask.
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
    # Register the event inside the Blocks context; the handler returns
    # (label text, mask image), matching the order of the outputs.
    source_image_loader.click(
        process_input_image,
        inputs=[img_source],
        outputs=[output_label, img_output],
    )

my_app.launch(debug=True)