File size: 3,328 Bytes
1899d85
8c16ebc
d905150
 
 
 
 
ea62197
8c16ebc
75d8751
d905150
 
 
 
 
 
7aeba2e
 
 
 
 
 
d905150
75d8751
d905150
 
 
 
75d8751
 
d905150
 
 
 
 
 
 
 
 
75d8751
 
d905150
 
 
 
3b78676
75d8751
 
3b78676
 
 
d905150
75d8751
d905150
d22dc1f
ea62197
d905150
 
67680bf
 
7e182e3
8a7486c
8b7ee47
 
 
 
d905150
 
 
75d8751
d905150
5a36cdc
7058c31
67680bf
ea62197
 
c25a556
 
 
d905150
2302330
d905150
 
 
c48b854
5a36cdc
d905150
0e5a4ff
6577280
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os, io
import cv2
import gradio as gr
import tensorflow as tf
import numpy as np
import keras.backend as K

from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras


# Image geometry (height, width, channels); presumably the size the network was
# trained on. NOTE(review): not referenced elsewhere in this file — uploads are
# passed to the model at their native size.
resized_shape = (768, 768, 3)

# Row / column step used to subsample the uploaded image in gen_pred;
# (1, 1) keeps every pixel.
IMG_SCALING = (1, 1)

# Serialized Keras model, resolved relative to the current working directory.
# A copy was previously downloaded with gdown from:
#     https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk
model_file = './seg_unet_model.h5'

#Custom objects for model

def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=0.5, ce_ratio=0.5):
    """Combo loss: weighted binary cross-entropy blended with soft Dice.

    Computes ``ce_ratio * weighted_ce - (1 - ce_ratio) * dice`` over every
    element of the (flattened) batch.

    Args:
        y_true: ground-truth mask; cast to float32 and flattened.
        y_pred: predicted probabilities; cast to float32 and flattened.
        eps: lower bound applied to both log arguments to avoid log(0).
        smooth: additive smoothing for the Dice numerator and denominator.
        alpha: weight of the positive-target CE term; the negative-target
            term gets ``1 - alpha``. 0.5 is a neutral default — the value
            used at training time is not recorded in this file.
        ce_ratio: share of the CE term in the final blend (0.5 by default,
            same caveat as ``alpha``).

    Returns:
        Scalar loss tensor.
    """
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    # Soft Dice on the unclipped predictions.
    intersection = K.sum(targets * inputs)
    dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)

    # Clip the *log arguments* rather than `inputs`: in float32, 1.0 - 1e-9
    # rounds to exactly 1.0, so clipping inputs to [eps, 1 - eps] still let
    # log(1 - inputs) hit log(0) = -inf for saturated predictions.
    pos_log = K.log(K.clip(inputs, eps, 1.0))
    neg_log = K.log(K.clip(1.0 - inputs, eps, 1.0))

    # alpha weights only the positive term; (1 - alpha) weights the negative
    # term (the previous parenthesisation multiplied both terms by alpha).
    out = -(alpha * targets * pos_log + (1 - alpha) * (1.0 - targets) * neg_log)
    weighted_ce = K.mean(out, axis=-1)

    return (ce_ratio * weighted_ce) - ((1 - ce_ratio) * dice)

def dice_coef(y_true, y_pred, smooth=1, threshold=0.5):
    """Hard Dice coefficient, averaged over the batch.

    Predictions are binarised with ``y_pred > threshold`` before comparison.
    Previously they were cast straight to int32, which truncates every
    probability below 1.0 to 0 and made the metric ignore almost all
    predicted pixels.

    Args:
        y_true: ground-truth mask tensor; cast to float32.
        y_pred: predicted probability tensor.
        smooth: additive smoothing for numerator and denominator.
        threshold: probability above which a pixel counts as positive.

    Returns:
        Scalar float32 tensor: per-sample Dice averaged over axis 0.
    """
    y_true = tf.dtypes.cast(y_true, tf.float32)
    y_pred = tf.dtypes.cast(y_pred > threshold, tf.float32)
    # Sums over axes 1-3 assume rank-4 (batch, H, W, C) tensors — same
    # assumption as before; TODO confirm against the model's output shape.
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    return K.mean((2. * intersection + smooth) / (union + smooth), axis=0)

def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Binary focal loss.

    Args:
        y_true: ground-truth mask; elements equal to 1 are positives and
            elements equal to 0 are negatives.
        y_pred: predicted probabilities, same shape as ``y_true``.
        gamma: focusing exponent applied to the modulating factor.
        alpha: weight of the positive term; negatives get ``1 - alpha``.

    Returns:
        Scalar loss tensor: negated sum of the mean positive and mean
        negative terms.
    """
    # Where the target is 1 keep the prediction; elsewhere use 1 so that
    # (1 - p)^gamma and log(p) make the positive term vanish.
    p_pos = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    # Where the target is 0 keep the prediction; elsewhere use 0 so that
    # p^gamma makes the negative term vanish.
    p_neg = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))

    positive_term = K.mean(alpha * K.pow(1. - p_pos, gamma) * K.log(p_pos + K.epsilon()))
    negative_term = K.mean((1 - alpha) * K.pow(p_neg, gamma) * K.log(1. - p_neg + K.epsilon()))
    return -positive_term - negative_term

# Load the trained segmentation network from `model_file`. The custom loss and
# metric functions are registered under the names stored in the .h5 file so
# Keras can resolve them while deserializing.
seg_model = keras.models.load_model(
    model_file,
    custom_objects={
        'Combo_loss': Combo_loss,
        'focal_loss_fixed': focal_loss_fixed,
        'dice_coef': dice_coef,
    },
)

# Subplot grid used by gen_pred for its single prediction panel.
rows = 1
columns = 1

def gen_pred(img, model=seg_model):
    """Run the segmentation model on one uploaded image and plot its output.

    Args:
        img: PIL image from the Gradio upload widget; converted to RGB here.
            ``None`` (submit pressed with nothing uploaded) is reported to
            the user instead of crashing.
        model: Keras model to run; defaults to the module-level ``seg_model``.

    Returns:
        A matplotlib ``Figure`` showing the prediction (batch axis removed)
        via ``imshow`` with axes hidden.

    Raises:
        gr.Error: if no image was supplied.
    """
    # Local import keeps this edit independent of the file's import block.
    from matplotlib.figure import Figure

    if img is None:
        raise gr.Error("Please upload an image.")

    # PIL's RGB conversion already gives the channel order the old
    # RGB -> BGR -> (cv2) RGB round trip ended with, so that no-op is dropped.
    rgb = np.asarray(img.convert('RGB'))
    rgb = rgb[::IMG_SCALING[0], ::IMG_SCALING[1]]

    # Scale to [0, 1] and add a leading batch axis of size 1.
    batch = np.expand_dims(rgb / 255, axis=0)
    pred = np.squeeze(model.predict(batch), axis=0)

    # Use the object-oriented Figure API instead of pyplot: pyplot keeps every
    # figure alive in global state (one leaked figure per request), and
    # plt.show() warns or blocks when called inside a web server.
    fig = Figure(figsize=(3, 3))
    ax = fig.add_subplot(rows, columns, 1)
    ax.imshow(pred)
    ax.axis('off')
    return fig

title = "<h1 style='text-align: center;'>Semantic Segmentation (Airbus Ship Detection Challenge)</h1>"
description = "Upload an image and get prediction mask"

# Bind the app to a module-level name so it can be imported and reused
# (Gradio's reload mode looks for `demo` by default). Start the server only
# when this file is executed as a script, not as an import side effect.
demo = gr.Interface(fn=gen_pred,
                    inputs=[gr.components.Image(type='pil')],
                    outputs=["plot"],
                    title=title,
                    examples=[["00c3db267.jpg"], ["00dc34840.jpg"], ["00371aa92.jpg"]],
                    description=description)

if __name__ == "__main__":
    demo.launch()