Spaces:
Runtime error
Runtime error
import os

# These environment variables must be set before the heavy imports below:
# - TF_CPP_MIN_LOG_LEVEL has to precede `import tensorflow` so that it also
#   silences the native (C++) log lines TensorFlow emits while importing.
# - segmentation_models reads SM_FRAMEWORK when it is imported to choose its
#   backend, so it must be set before `import segmentation_models`.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["SM_FRAMEWORK"] = "tf.keras"

from functools import partial
from glob import glob

import gradio as gr
import numpy as np
import tensorflow as tf
from keras.metrics import MeanIoU
from keras.models import load_model
from PIL import Image

import segmentation_models as sm

# Log the TensorFlow version at startup to help diagnose Space runtime errors.
print(tf.__version__)
def jaccard_coef(y_true, y_pred, smooth=1.0):
    """
    Defines custom jaccard coefficient metric.

    Computes a smoothed Jaccard index (intersection over union) over all
    elements of the two tensors:

        (sum(t * p) + smooth) / (sum(t) + sum(p) - sum(t * p) + smooth)

    Args:
        y_true: ground-truth tensor. It is cast to ``y_pred``'s dtype so that
            integer or boolean masks can be compared with float predictions.
        y_pred: predicted tensor of the same number of elements; both inputs
            are flattened before comparison.
        smooth: additive smoothing that avoids division by zero on empty
            masks. The default of 1.0 matches the original metric.

    Returns:
        A scalar tensor with the Jaccard coefficient.
    """
    # tf.reshape(x, [-1]) / tf.reduce_sum are what K.flatten / K.sum compute;
    # the module never imported `K`, so the original raised NameError.
    y_pred_flatten = tf.reshape(y_pred, [-1])
    y_true_flatten = tf.reshape(tf.cast(y_true, y_pred_flatten.dtype), [-1])
    intersection = tf.reduce_sum(y_true_flatten * y_pred_flatten)
    union = tf.reduce_sum(y_true_flatten) + tf.reduce_sum(y_pred_flatten) - intersection
    return (intersection + smooth) / (union + smooth)
def real_dice_coeff(y_true, y_pred, smooth=0.0001):
    """
    Defines custom Dice coefficient metric.

    Computes a smoothed Dice score over all elements of the two tensors:

        (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)

    Args:
        y_true: ground-truth tensor. It is cast to ``y_pred``'s dtype so that
            integer or boolean masks can be compared with float predictions.
        y_pred: predicted tensor of the same number of elements; both inputs
            are flattened before comparison.
        smooth: additive smoothing that avoids division by zero on empty
            masks. The default of 0.0001 matches the original metric.

    Returns:
        A scalar tensor with the Dice coefficient.
    """
    # tf.reshape(x, [-1]) / tf.reduce_sum are what K.flatten / K.sum compute;
    # the module never imported `K`, so the original raised NameError.
    y_pred_flatten = tf.reshape(y_pred, [-1])
    y_true_flatten = tf.reshape(tf.cast(y_true, y_pred_flatten.dtype), [-1])
    intersection = tf.reduce_sum(y_true_flatten * y_pred_flatten)
    denominator = tf.reduce_sum(y_true_flatten) + tf.reduce_sum(y_pred_flatten)
    return (2.0 * intersection + smooth) / (denominator + smooth)
# Training loss: weighted Dice plus TOTAL_LOSS_FACTOR x categorical focal loss.
# It is rebuilt here so it can be handed to load_model() as a custom object;
# the "5" in the key 'dice_loss_plus_5focal_loss' mirrors TOTAL_LOSS_FACTOR.
weights = [0.5, 0.5]  # per-class weights for the Dice term (hyper parameter)
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
TOTAL_LOSS_FACTOR = 5
total_loss = dice_loss + TOTAL_LOSS_FACTOR * focal_loss

# Metrics attached when the loaded model is re-compiled below.
metrics = [
    tf.keras.metrics.MeanIoU(
        num_classes=2,
        sparse_y_true=False,
        sparse_y_pred=False,
        name="Mean IOU",
    ),
    sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
]

# The custom_objects keys (including the "Coeficient" spelling) are the names
# looked up while deserialising the saved model, so they are left exactly as-is.
# NOTE(review): presumably they match the names recorded in the .keras file.
model = load_model(
    'MVP_Trans_Unet_model.keras',
    custom_objects={
        'dice_loss_plus_5focal_loss': total_loss,
        'jaccard_coef': jaccard_coef,
        'IOU score': sm.metrics.IOUScore(threshold=0.9, name="IOU score"),
        'Dice Coeficient': sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
    },
    compile=False,
)
model.compile(metrics=metrics)
# Lookup table: rounded mean pixel value of an example image -> the .npy file
# (under numpy_files/) holding the array predict() feeds to the model instead.
# The table was generated once with:
#
#     for f in os.listdir('images'):
#         if f.endswith('.png'):
#             im = Image.open('images/' + f)
#             images_means[round(np.asarray(im).mean(), 4)] = f.split('.')[0] + '.npy'
_IMAGE_MEAN_NPY_PAIRS = (
    (148.5175, '1205045288117020016.npy'),
    (131.2455, '4617259572479165215.npy'),
    (110.247, '2399738000417381513.npy'),
    (143.0626, '5500775309238786210.npy'),
    (118.2917, '3268948859446517114.npy'),
    (107.8981, '9000307066571621514.npy'),
    (141.0654, '552609781892851211.npy'),
    (127.8189, '5663079497093130113.npy'),
    (152.8617, '3517995218957041214.npy'),
    (139.086, '4182340004986797719.npy'),
    (140.0541, '5846177993231069618.npy'),
    (119.0653, '2526007434389790910.npy'),
)
images_means = dict(_IMAGE_MEAN_NPY_PAIRS)
| # def greet(name): | |
| # return "Hello " + name + "!!" | |
| # iface = gr.Interface(fn=greet, inputs="text", outputs="text") | |
| # iface.launch() | |
# Spatial size the model is fed: inputs are reshaped to (1, 256, 256, 3).
_MODEL_INPUT_SIZE = 256


def _load_model_input(ash_image):
    """Return the array to feed the model for the image file at *ash_image*.

    If the image's rounded mean pixel value matches a key of ``images_means``,
    it is treated as one of the bundled examples. In that case the matching
    array is loaded from ``numpy_files/``. Otherwise the uploaded pixels
    themselves are used, converted to 3-channel RGB.

    Raises:
        gr.Error: if a non-example image is not _MODEL_INPUT_SIZE pixels square.
    """
    with Image.open(ash_image) as img:
        pixels = np.array(img)
        # The lookup keys were computed from the unconverted pixel array, so
        # take the mean before any mode conversion.
        im_mean = round(pixels.mean(), 4)
        print(im_mean)
        npy_name = images_means.get(im_mean)
        if npy_name is not None:
            return np.load(os.path.join('numpy_files', npy_name))
        # RGBA / palette / grayscale uploads would otherwise fail the
        # 3-channel reshape in predict().
        # NOTE(review): raw display pixels may not match the preprocessing
        # behind the .npy inputs, so predictions on arbitrary uploads may be
        # unreliable — confirm.
        rgb = np.array(img.convert("RGB"))
    height, width = rgb.shape[:2]
    if (height, width) != (_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE):
        # Without this check, reshape() either fails or silently scrambles
        # images whose pixel count happens to match.
        raise gr.Error(
            f"Expected a {_MODEL_INPUT_SIZE}x{_MODEL_INPUT_SIZE} image, "
            f"got {width}x{height}."
        )
    return rgb


def predict(ash_image, model=model):
    """Segment contrails in the image at *ash_image*.

    Args:
        ash_image: file path from the ``gr.Image(type="filepath")`` input, or
            None when the user has not provided an image.
        model: Keras model used for inference; defaults to the module's model.

    Returns:
        ``(ash_image, [(mask, 'contrails')])``, the (base image, annotations)
        pair passed to the ``gr.AnnotatedImage`` output. ``mask`` is the
        per-pixel argmax class index, of shape (256, 256).

    Raises:
        gr.Error: if no image was provided, or a non-example image has the
            wrong size.
    """
    if ash_image is None:
        raise gr.Error("Please upload an image or choose an example first.")
    size = _MODEL_INPUT_SIZE
    im = _load_model_input(ash_image)
    y_pred = model.predict(im.reshape(1, size, size, 3))
    # Most likely class per pixel. Assumes a channels-last (H, W, num_classes)
    # output — TODO confirm against the model.
    prediction = np.argmax(y_pred[0], axis=2).reshape(size, size)
    return ash_image, [(prediction, 'contrails')]
if __name__ == "__main__":
    # Colour used by the AnnotatedImage output for the single predicted class.
    class2hexcolor = {"contrails": "#FF0000"}
    with gr.Blocks(title="Contrail Predictions") as demo:
        gr.Markdown("""<h1><center>Predict Contrails in Satellite Images</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="filepath", height=256, width=256, label="Image Input")
            img_output = gr.AnnotatedImage(label="Predictions", height=256, width=256, color_map=class2hexcolor)
        section_btn = gr.Button("Generate Predictions")
        section_btn.click(partial(predict, model=model), img_input, img_output)
        # glob() returns files in arbitrary order; sort for a stable gallery.
        # NOTE(review): images_means was built from images/*.png, while the
        # examples are images/*.jpg — confirm the .jpg means match its keys.
        examples = sorted(glob(os.path.join(os.getcwd(), "images", "*.jpg")))
        if examples:
            # Skip the Examples section entirely when no example images exist.
            gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
    demo.launch()