# Author: pemujo
# Last commit: "Update app.py" (127662f)
import gradio as gr
import tensorflow as tf
from keras.models import load_model
import os
import numpy as np
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
from keras.metrics import MeanIoU
from functools import partial
from glob import glob
from PIL import Image
# Print the TensorFlow version at startup so the runtime version is visible in the logs.
print(tf.__version__)
def jaccard_coef(y_true, y_pred, smooth=1.0):
    """
    Smoothed Jaccard (intersection-over-union) coefficient for use as a Keras metric.

    Both inputs are flattened, so the score is computed over every element of
    the batch at once:

        (sum(t * p) + smooth) / (sum(t) + sum(p) - sum(t * p) + smooth)

    With the default ``smooth=1.0`` the result is 1.0 when both inputs are all
    zeros, instead of a division by zero.

    Args:
        y_true: Ground-truth tensor (or array-like).
        y_pred: Prediction tensor with the same number of elements as ``y_true``.
        smooth: Additive smoothing constant for numerator and denominator.

    Returns:
        Scalar tensor with the smoothed Jaccard coefficient.
    """
    # Use TensorFlow ops directly: `tf` is imported at module level, whereas the
    # keras backend alias `K` is not imported anywhere in this file.
    y_pred_flatten = tf.reshape(y_pred, [-1])
    # Cast labels to the prediction dtype so integer masks can be multiplied
    # with floating-point predictions.
    y_true_flatten = tf.cast(tf.reshape(y_true, [-1]), y_pred_flatten.dtype)
    intersection = tf.reduce_sum(y_true_flatten * y_pred_flatten)
    union = tf.reduce_sum(y_true_flatten) + tf.reduce_sum(y_pred_flatten) - intersection
    return (intersection + smooth) / (union + smooth)
def real_dice_coeff(y_true, y_pred, smooth=0.0001):
    """
    Smoothed Dice coefficient for use as a Keras metric.

    Both inputs are flattened, so the score is computed over every element of
    the batch at once:

        (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)

    Args:
        y_true: Ground-truth tensor (or array-like).
        y_pred: Prediction tensor with the same number of elements as ``y_true``.
        smooth: Small additive constant that keeps the ratio defined when both
            inputs are all zeros.

    Returns:
        Scalar tensor with the smoothed Dice coefficient.
    """
    # Use TensorFlow ops directly: `tf` is imported at module level, whereas the
    # keras backend alias `K` is not imported anywhere in this file.
    y_pred_flatten = tf.reshape(y_pred, [-1])
    # Cast labels to the prediction dtype so integer masks can be multiplied
    # with floating-point predictions.
    y_true_flatten = tf.cast(tf.reshape(y_true, [-1]), y_pred_flatten.dtype)
    intersection = tf.reduce_sum(y_true_flatten * y_pred_flatten)
    denominator = tf.reduce_sum(y_true_flatten) + tf.reduce_sum(y_pred_flatten)
    return (2.0 * intersection + smooth) / (denominator + smooth)
# Combined loss: class-weighted Dice loss plus a scaled categorical focal loss.
# It is handed to load_model below under the name 'dice_loss_plus_5focal_loss'.
weights = [0.5, 0.5]  # per-class Dice weights (hyperparameter)
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
TOTAL_LOSS_FACTOR = 5
total_loss = dice_loss + (TOTAL_LOSS_FACTOR * focal_loss)

# Metrics attached when the loaded model is re-compiled.
metrics = [
    tf.keras.metrics.MeanIoU(
        num_classes=2,
        sparse_y_true=False,
        sparse_y_pred=False,
        name="Mean IOU",
    ),
    sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
]

# Load the model without its training configuration. The custom-object keys
# (spelling included) presumably must match the names stored in the saved
# model file, so they are kept verbatim.
# A development copy was previously loaded from '../../../fast-disk/w210- '.
model = load_model(
    'MVP_Trans_Unet_model.keras',
    custom_objects={
        'dice_loss_plus_5focal_loss': total_loss,
        'jaccard_coef': jaccard_coef,
        'IOU score': sm.metrics.IOUScore(threshold=0.9, name="IOU score"),
        'Dice Coeficient': sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
    },
    compile=False,
)
model.compile(metrics=metrics)
# Lookup table: rounded mean pixel value of each bundled example image ->
# file name of a pre-computed .npy array. predict() loads that array from
# numpy_files/ and uses it as the model input when an uploaded image matches.
# The table was generated with:
#
#   images_means = {}
#   for f in os.listdir('images'):
#       if f.endswith('.png'):
#           im = Image.open('images/' + f)
#           images_means[round(np.asarray(im).mean(), 4)] = f.split('.')[0] + '.npy'
#
# NOTE(review): the generator above scans images/*.png, while the UI examples
# are globbed from images/*.jpg. Confirm the keys still match the files that
# are actually uploaded. Keying on a float mean is fragile: any re-encoding of
# an image changes its key.
images_means = {148.5175: '1205045288117020016.npy',
131.2455: '4617259572479165215.npy',
110.247: '2399738000417381513.npy',
143.0626: '5500775309238786210.npy',
118.2917: '3268948859446517114.npy',
107.8981: '9000307066571621514.npy',
141.0654: '552609781892851211.npy',
127.8189: '5663079497093130113.npy',
152.8617: '3517995218957041214.npy',
139.086: '4182340004986797719.npy',
140.0541: '5846177993231069618.npy',
119.0653: '2526007434389790910.npy'}
# def greet(name):
# return "Hello " + name + "!!"
# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# iface.launch()
def predict(ash_image, model=model):
    """
    Segment contrails in an uploaded image for the Gradio UI.

    If the uploaded image's mean pixel value (rounded to 4 decimals) is a key
    of ``images_means``, the matching pre-computed array from ``numpy_files/``
    is used as the model input. Otherwise the decoded image itself, converted
    to 3-channel RGB, is used.

    Args:
        ash_image: Path to the uploaded image (the UI's ``gr.Image`` uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        model: Keras model applied to a ``(1, 256, 256, 3)`` batch. The argmax
            over the last axis of its first output is taken as the class mask.

    Returns:
        ``(ash_image, [(mask, 'contrails')])``, where ``mask`` is a 256x256
        array of predicted class indices. This is the value format passed to
        ``gr.AnnotatedImage``.

    Raises:
        gr.Error: if no image was supplied, or the model input does not
            contain exactly 256 * 256 * 3 values.
    """
    side = 256  # the model input is side x side x 3 (see the reshape below)
    if ash_image is None:
        raise gr.Error("Please upload an image before generating predictions.")

    # Context manager so PIL's underlying file handle is closed promptly.
    with Image.open(ash_image) as img:
        # images_means was built from the raw (unconverted) image array, so the
        # lookup key must be computed before any mode conversion.
        im_mean = round(np.asarray(img).mean(), 4)
        print(im_mean)
        npy_name = images_means.get(im_mean)
        if npy_name is None:
            # Fallback: normalise RGBA / grayscale / palette images to 3 channels
            # so the reshape below does not fail on the channel count.
            im = np.asarray(img.convert("RGB"))
        else:
            im = np.load(os.path.join('numpy_files', npy_name))

    if im.size != side * side * 3:
        raise gr.Error(
            f"Expected a {side}x{side} RGB image, got an array of shape {im.shape}."
        )
    y_pred = model.predict(im.reshape(1, side, side, 3))
    prediction = np.argmax(y_pred[0], axis=-1).reshape(side, side)
    seg_info = [(prediction, 'contrails')]
    return ash_image, seg_info
if __name__ == "__main__":
    # Overlay colour for each label that predict() returns.
    class2hexcolor = {"contrails": "#FF0000"}
    with gr.Blocks(title="Contrail Predictions") as demo:
        gr.Markdown("""<h1><center>Predict Contrails in Satellite Images</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="filepath", height=256, width=256, label="Image Input")
            img_output = gr.AnnotatedImage(label="Predictions", height=256, width=256, color_map=class2hexcolor)
        section_btn = gr.Button("Generate Predictions")
        section_btn.click(partial(predict, model=model), img_input, img_output)
        # Build the glob pattern with os.path.join, and sort so the example
        # gallery order is deterministic across filesystems.
        examples = sorted(glob(os.path.join(os.getcwd(), "images", "*.jpg")))
        # Only render the gallery when example files are actually present.
        if examples:
            gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
    demo.launch()