Spaces:
Runtime error
Runtime error
import os

# Both variables are consulted when the libraries below initialise, so they
# must be set before those imports run: TensorFlow's C++ logger reads
# TF_CPP_MIN_LOG_LEVEL ("2" filters INFO and WARNING messages), and
# segmentation_models selects its Keras backend from SM_FRAMEWORK.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["SM_FRAMEWORK"] = "tf.keras"

from functools import partial
from glob import glob

import gradio as gr
import numpy as np
import tensorflow as tf
from keras.metrics import MeanIoU
from keras.models import load_model
from PIL import Image

# Imported after tensorflow/keras (as before) and after SM_FRAMEWORK is set.
import segmentation_models as sm

print(tf.__version__)
def jaccard_coef(y_true, y_pred):
    """
    Defines custom jaccard coefficient metric.

    Smoothed Jaccard index (intersection over union) between a ground-truth
    mask and a prediction. Both tensors are flattened first, so the score is
    pooled over every element of the batch (all samples and channels) rather
    than averaged per sample.

    Args:
        y_true: ground-truth mask; cast to ``y_pred``'s dtype so integer
            label masks can be multiplied with float predictions.
        y_pred: predicted mask or probabilities with the same number of
            elements as ``y_true``.

    Returns:
        Scalar tensor ``(I + 1) / (sum(y_true) + sum(y_pred) - I + 1)`` where
        ``I = sum(y_true * y_pred)``. The +1 smoothing keeps the ratio defined
        (and equal to 1.0) when both inputs are all zero.
    """
    # Use TensorFlow ops directly: this module never imports keras.backend
    # as ``K``, so the former K.flatten/K.sum calls raised NameError.
    y_pred = tf.convert_to_tensor(y_pred)
    y_pred_flatten = tf.reshape(y_pred, [-1])
    y_true_flatten = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    intersection = tf.reduce_sum(y_true_flatten * y_pred_flatten)
    union = tf.reduce_sum(y_true_flatten) + tf.reduce_sum(y_pred_flatten) - intersection
    return (intersection + 1.0) / (union + 1.0)
def real_dice_coeff(y_true, y_pred):
    """Smoothed Dice (F1) coefficient between a mask and a prediction.

    Both tensors are flattened first, so the score is pooled over every
    element of the batch (all samples and channels) rather than averaged
    per sample.

    Args:
        y_true: ground-truth mask; cast to ``y_pred``'s dtype so integer
            label masks can be multiplied with float predictions.
        y_pred: predicted mask or probabilities with the same number of
            elements as ``y_true``.

    Returns:
        Scalar tensor ``(2*I + s) / (sum(y_true) + sum(y_pred) + s)`` where
        ``I = sum(y_true * y_pred)`` and ``s = 1e-4``; the smoothing term
        avoids division by zero when both inputs are all zero.
    """
    smooth = 0.0001
    # Use TensorFlow ops directly: this module never imports keras.backend
    # as ``K``, so the former K.flatten/K.sum calls raised NameError.
    y_pred = tf.convert_to_tensor(y_pred)
    y_pred_flatten = tf.reshape(y_pred, [-1])
    y_true_flatten = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    intersection = tf.reduce_sum(y_true_flatten * y_pred_flatten)
    total = tf.reduce_sum(y_true_flatten) + tf.reduce_sum(y_pred_flatten)
    return (2.0 * intersection + smooth) / (total + smooth)
# --- Training objects the saved model refers to -----------------------------
# Per-class weights for the Dice term (a tuning hyper-parameter).
weights = [0.5, 0.5]
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()

# Combined objective: Dice + TOTAL_LOSS_FACTOR x focal.
TOTAL_LOSS_FACTOR = 5
total_loss = dice_loss + (TOTAL_LOSS_FACTOR * focal_loss)

# Metrics attached when the model is re-compiled below.
metrics = [
    tf.keras.metrics.MeanIoU(
        num_classes=2,
        sparse_y_true=False,
        sparse_y_pred=False,
        name="Mean IOU",
    ),
    sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
]

# --- Model ------------------------------------------------------------------
# compile=False skips restoring the training configuration; the model is then
# compiled with metrics only (no loss/optimizer), which is enough for predict().
# NOTE(review): the custom_objects keys presumably match names stored in the
# saved model's config -- keep their spelling (incl. "Coeficient") as-is.
model = load_model(
    'MVP_Trans_Unet_model.keras',
    custom_objects={
        'dice_loss_plus_5focal_loss': total_loss,
        'jaccard_coef': jaccard_coef,
        'IOU score': sm.metrics.IOUScore(threshold=0.9, name="IOU score"),
        'Dice Coeficient': sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
    },
    compile=False,
)
model.compile(metrics=metrics)
# Lookup table: rounded (4 d.p.) mean pixel value of a bundled sample image
# -> name of the pre-processed .npy array under numpy_files/ that predict()
# feeds to the model in place of that image's raw pixels.
#
# It was produced by this snippet (kept for reference, not executed):
#
#     images_means = {}
#     for f in os.listdir('images'):
#         if f.endswith('.png'):
#             im = Image.open('images/' + f)
#             images_means[round(np.asarray(im).mean(), 4)] = f.split('.')[0] + '.npy'
#
# NOTE(review): the snippet scans *.png files while the demo's examples are
# *.jpg; since lookups use exact float equality, confirm the means were taken
# from the files users actually submit.
images_means = {
    148.5175: '1205045288117020016.npy',
    131.2455: '4617259572479165215.npy',
    110.247: '2399738000417381513.npy',
    143.0626: '5500775309238786210.npy',
    118.2917: '3268948859446517114.npy',
    107.8981: '9000307066571621514.npy',
    141.0654: '552609781892851211.npy',
    127.8189: '5663079497093130113.npy',
    152.8617: '3517995218957041214.npy',
    139.086: '4182340004986797719.npy',
    140.0541: '5846177993231069618.npy',
    119.0653: '2526007434389790910.npy',
}
def predict(ash_image, model=model):
    """Segment contrails in one image for the Gradio demo.

    If the image's rounded mean pixel value matches an entry in
    ``images_means`` (i.e. it is one of the bundled sample images), the
    matching pre-processed array from ``numpy_files/`` is fed to the model;
    otherwise the image's own pixels, converted to RGB, are used.

    Args:
        ash_image: path to the input image file, as supplied by
            ``gr.Image(type="filepath")``.
        model: Keras segmentation model; defaults to the module-level model.

    Returns:
        ``(ash_image, [(mask, "contrails")])``: the unchanged input path plus
        a 256x256 array of per-pixel argmax class indices, the value format
        fed to the ``gr.AnnotatedImage`` output.

    Raises:
        ValueError: if the image has no pre-processed counterpart and is not
            256x256 pixels.
    """
    side = 256  # spatial size passed to the model (see reshape below)

    # ``with`` closes the file handle; np.asarray copies the pixels out first.
    with Image.open(ash_image) as img:
        # Mean over the raw array (all channels, original mode) so it is
        # computed the same way the images_means table was built.
        im_mean = round(np.asarray(img).mean(), 4)
        # RGBA / grayscale / palette images would otherwise fail the
        # (1, 256, 256, 3) reshape; RGB input is left unchanged by this.
        rgb = np.asarray(img.convert("RGB"))
    print(im_mean)  # logged to help match new samples against images_means

    if im_mean in images_means:
        model_input = np.load('numpy_files/' + images_means[im_mean])
    else:
        if rgb.shape[:2] != (side, side):
            raise ValueError(
                f"expected a {side}x{side} image, got {rgb.shape[1]}x{rgb.shape[0]}"
            )
        model_input = rgb

    y_pred = model.predict(model_input.reshape(1, side, side, 3))
    prediction = np.argmax(y_pred[0], axis=-1).reshape(side, side)
    return ash_image, [(prediction, 'contrails')]
if __name__ == "__main__":
    # Overlay colour for each label name returned by predict().
    class2hexcolor = {"contrails": "#FF0000"}

    with gr.Blocks(title="Contrail Predictions") as demo:
        gr.Markdown("""<h1><center>Predict Contrails in Satellite Images</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="filepath", height=256, width=256, label="Image Input")
            img_output = gr.AnnotatedImage(label="Predictions", height=256, width=256, color_map=class2hexcolor)
        section_btn = gr.Button("Generate Predictions")
        section_btn.click(partial(predict, model=model), img_input, img_output)

        # glob() returns files in arbitrary order; sort for a stable gallery.
        examples = sorted(glob(os.path.join(os.getcwd(), "images", "*.jpg")))
        # Skip the examples widget entirely when no sample images are present.
        if examples:
            gr.Examples(examples=examples, inputs=img_input, outputs=img_output)

    demo.launch()