"""Gradio demo: U-Net ship segmentation (Airbus Ship Detection Challenge).

Loads a trained Keras model from ``seg_unet_model.h5`` and serves a web UI that
takes an uploaded image and displays the model's predicted mask.
"""
import os, io
import cv2
import gradio as gr
import tensorflow as tf
import numpy as np
import keras.backend as K
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras

# Not referenced below; presumably the training input size - TODO confirm.
resized_shape = (768, 768, 3)
# (row step, column step) used to subsample the input image in gen_pred.
IMG_SCALING = (1, 1)

# Weights used by Combo_loss. They are only needed so the saved model's loss can
# be deserialized; model.predict never evaluates the loss.
# NOTE(review): 0.5 / 0.5 assumed - confirm these match the training run.
ALPHA = 0.5     # weight of the positive-class CE term; (1 - ALPHA) for negatives
CE_RATIO = 0.5  # share of weighted CE vs. Dice in the combined loss

# Model weights can be fetched (e.g. with gdown) from
# https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk
model_file = "./seg_unet_model.h5"


# ---------------------------------------------------------------------------
# Custom objects referenced by the saved model
# ---------------------------------------------------------------------------

def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
    """Combined weighted binary cross-entropy and Dice loss.

    Both tensors are flattened, so the loss is computed over every pixel of the
    batch. Returns ``CE_RATIO * weighted_ce - (1 - CE_RATIO) * dice``.
    """
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    intersection = K.sum(targets * inputs)
    dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)

    # Clip so log() never sees exactly 0 or 1.
    inputs = K.clip(inputs, eps, 1.0 - eps)
    # Weighted BCE: ALPHA weights the positive term, (1 - ALPHA) the negative
    # term. (Previously ALPHA was mis-parenthesised to multiply both terms.)
    out = -(ALPHA * targets * K.log(inputs)
            + (1 - ALPHA) * (1.0 - targets) * K.log(1.0 - inputs))
    weighted_ce = K.mean(out, axis=-1)
    return (CE_RATIO * weighted_ce) - ((1 - CE_RATIO) * dice)


def dice_coef(y_true, y_pred, smooth=1, threshold=0.5):
    """Mean per-image Dice coefficient of the binarised prediction.

    Sums run over axes 1-3, so 4-D (batch, ..., ..., ...) tensors are expected.
    ``y_pred`` is binarised at ``threshold``. Casting probabilities directly to
    int (the previous behaviour) truncated every value below 1.0 to 0.
    """
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    return K.mean((2. * intersection + smooth) / (union + smooth), axis=0)


def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Binary focal loss.

    ``pt_1`` holds predictions at positive pixels (1 elsewhere, contributing 0);
    ``pt_0`` holds predictions at negative pixels (0 elsewhere, contributing 0).
    """
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    loss = (-K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1 + K.epsilon()))
            - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon())))
    return loss


# Load the model once at start-up (uses model_file instead of repeating the path).
seg_model = keras.models.load_model(
    model_file,
    custom_objects={
        'Combo_loss': Combo_loss,
        'focal_loss_fixed': focal_loss_fixed,
        'dice_coef': dice_coef,
    },
)

# Subplot grid for the output figure (single panel: the predicted mask).
rows = 1
columns = 1


def gen_pred(img, model=seg_model):
    """Run the segmentation model on one image and return a figure of the mask.

    Args:
        img: PIL image from the Gradio ``Image(type='pil')`` input, or ``None``
            when the user submits without uploading anything.
        model: Keras model to run; defaults to the model loaded at start-up.

    Returns:
        A matplotlib Figure showing the predicted mask, or ``None`` when no
        image was supplied.
    """
    if img is None:
        return None

    # RGB uint8 array. The former RGB->BGR->RGB round trip via OpenCV was a no-op.
    arr = np.array(img.convert('RGB'))
    arr = arr[::IMG_SCALING[0], ::IMG_SCALING[1]]
    batch = tf.expand_dims(arr / 255, axis=0)  # scale to [0, 1], add batch axis

    pred = np.squeeze(model.predict(batch), axis=0)

    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_subplot(rows, columns, 1)
    ax.imshow(pred)
    ax.axis('off')
    # Detach the figure from pyplot's manager so each request doesn't leak an
    # open figure (plt.show() does nothing useful in a server). The returned
    # Figure can still be rendered by Gradio via savefig.
    plt.close(fig)
    return fig


title = "Semantic Segmentation (Airbus Ship Detection Challenge)"
description = "Upload an image and get prediction mask"

# NOTE(review): enable_queue is a Gradio 3.x Interface argument; newer Gradio
# versions use demo.queue() instead - confirm against the pinned version.
demo = gr.Interface(
    fn=gen_pred,
    inputs=[gr.Image(type='pil')],
    outputs=["plot"],
    title=title,
    examples=[["00c3db267.jpg"], ["00dc34840.jpg"], ["00371aa92.jpg"]],
    description=description,
    enable_queue=True,
)

if __name__ == "__main__":
    demo.launch()