# app.py -- revision d3759bf ("Update app.py"), uploaded by jgyegyiri2023.
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import gradio as gr
def jaccard_coef(y_true, y_pred):
    """Return the smoothed Jaccard (intersection-over-union) coefficient.

    Both inputs are flattened first, so the score is computed over every
    element at once rather than per sample or per class.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor; multiplied element-wise with ``y_true``,
            so presumably of the same shape -- confirm against the caller.

    Returns:
        Scalar tensor equal to
        ``(I + 1) / (sum(y_true) + sum(y_pred) - I + 1)`` where ``I`` is the
        sum of the element-wise product of the two flattened inputs.
    """
    # Added to numerator and denominator so the ratio stays defined when
    # both sums are zero.
    smooth = 1.0

    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)

    overlap = K.sum(truth * guess)
    union = K.sum(truth) + K.sum(guess) - overlap
    return (overlap + smooth) / (union + smooth)
# Loss objects rebuilt at import time so that ``load_model`` can resolve the
# custom names passed in ``custom_objects`` below (presumably the names the
# checkpoint was saved with -- confirm against the training script).
weights = [0.1666] * 6  # one equal weight per class, six entries
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
# Expression kept exactly as written: the combined loss is registered under
# the key 'dice_loss_plus_1focal_loss' in the mapping below.
total_loss = dice_loss + (1 * focal_loss)

_custom_objects = {
    'dice_loss_plus_1focal_loss': total_loss,
    'jaccard_coef': jaccard_coef,
}
satellite_model = load_model(
    'model/model_checkpoint.h5',
    custom_objects=_custom_objects,
)
def process_input_image(image_source):
    """Segment one image with ``satellite_model`` and return a label and mask.

    Args:
        image_source: Image array handed over by the Gradio input component,
            or ``None`` when the button is pressed before an image has been
            selected. Presumably (256, 256, channels) given the ``shape``
            set on the input component -- confirm against the model input.

    Returns:
        A ``(label_text, mask)`` tuple. ``mask`` is a 2-D ``uint8`` array
        holding ``class_index * 50`` for every pixel, or ``None`` when no
        image was supplied.
    """
    # Without this guard a missing image is pushed into ``predict`` and
    # fails with an opaque error; report it to the user instead.
    if image_source is None:
        return 'No image provided', None

    # The model expects a batch dimension in front of the single image.
    batch = np.expand_dims(image_source, 0)
    prediction = satellite_model.predict(batch)

    # Most likely class per pixel (last axis holds the class scores), then
    # drop the batch dimension again.
    class_map = np.argmax(prediction, axis=3)[0, :, :]

    # Spread the class indices over the grayscale range so they are visibly
    # distinct. ``argmax`` yields a wide integer dtype; casting to uint8
    # (lossless for the six classes here, max 5 * 50 = 250) hands the image
    # component a standard 8-bit image instead of relying on how it converts
    # other integer dtypes.
    mask = (class_map * 50).astype(np.uint8)
    return 'Predicted Masked Image', mask
# Gradio UI: one tab with the input image and trigger button on the left and
# the result label and predicted mask on the right.
with gr.Blocks() as my_app:
    # Heading text; the original spelled "Satellite" as "Statellite".
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio, load your satellite image and follow the instructions")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image", shape=(256, 256))
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")

    # Event wiring must happen inside the Blocks context. The two outputs
    # receive the (label, mask) tuple returned by process_input_image, in
    # that order.
    source_image_loader.click(
        fn=process_input_image,
        inputs=[img_source],
        outputs=[output_label, img_output],
    )

my_app.launch(debug=True)