import gradio as gr
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide GPUs so TensorFlow (imported below) runs on CPU
from glob import glob 

import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.io import imsave
from skimage.filters import threshold_otsu

from doodleverse_utils.prediction_imports import *
from doodleverse_utils.imports import *


# Load the pre-trained segmentation model (compiled at load time)
filepath = './saved_model'
model = tf.keras.models.load_model(filepath, compile=True)

#segmentation
def segment(input_img, use_tta, use_otsu, dims=(512, 512)):
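    """Segment a satellite image into N classes and write the download artifacts.

    Args:
        input_img: RGB image as a (rows, cols, 3) numpy array.
        use_tta: if True, average predictions over flipped copies (test-time augmentation).
        use_otsu: if True, derive the binary water mask from an Otsu threshold on the
            summed water/whitewater scores instead of the argmax label.
        dims: model input size (rows, cols).

    Returns:
        The colour label image, the matplotlib overlay, and the three saved PNG filenames.
    """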

    N = 4  # number of classes: water, whitewater, sediment, other

    if use_otsu:
        print("Use Otsu threshold")
    else:
        print("No Otsu threshold")

    if use_tta:
        print("Use TTA")
    else:
        print("Do not use TTA")


    worig, horig, channels = input_img.shape

    w, h = dims[0], dims[1]

    print("Original dimensions {}x{}".format(worig,horig))
    print("New dimensions {}x{}".format(w,h))

    # Standardize the image, resize it to the model's input size, and add a batch axis
    img = standardize(input_img)

    img = resize(img, dims, preserve_range=True, clip=True)

    img = np.expand_dims(img, axis=0)
    
    est_label = model.predict(img)
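    # est_label holds per-pixel class scores with a leading batch axis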

    if use_tta:
        # Test-time augmentation: predict on flipped copies and un-flip the results.
        # img is batched (1, H, W, C), so flip along the image axes (1 = rows, 2 = cols)
        # rather than using flipud/fliplr, which would act on the batch and row axes.
        est_label2 = np.flip(model.predict(np.flip(img, axis=1), batch_size=1), axis=1)
        est_label3 = np.flip(model.predict(np.flip(img, axis=2), batch_size=1), axis=2)
        est_label4 = np.flip(np.flip(model.predict(np.flip(np.flip(img, axis=1), axis=2), batch_size=1), axis=2), axis=1)

        # Soft voting: average the softmax scores across the four predictions
        est_label = (est_label + est_label2 + est_label3 + est_label4) / 4
    
    # Drop the batch axis, resize the class scores back to the original image size,
    # and take the per-pixel argmax to get an integer label image
    pred = np.squeeze(est_label, axis=0)
    pred = resize(pred, (worig, horig), preserve_range=True, clip=True)

    mask = np.argmax(pred, -1)

    imsave("greyscale_download_me.png", mask.astype('uint8'))
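    # the greyscale PNG stores the integer class index of each pixel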
    
    class_label_colormap = [
        "#3366CC",
        "#DC3912",
        "#FF9900",
        "#109618",
        "#990099",
        "#0099C6",
        "#DD4477",
        "#66AA00",
        "#B82E2E",
        "#316395",
    ]
    
    # keep one colour per class
    class_label_colormap = class_label_colormap[:N]

    color_label = label_to_colors(
        mask,
        input_img[:, :, 0] == 0,
        alpha=128,
        colormap=class_label_colormap,
        color_class_offset=0,
        do_alpha=False,
    )
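    # color_label is an RGB rendering of the label image, one colour per class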
    
    imsave("color_download_me.png", color_label)


    if use_otsu:
        # Otsu threshold on the combined water + whitewater scores
        c1 = pred[:, :, 0]  # water
        c2 = pred[:, :, 1]  # whitewater
        water = c1 + c2
        water /= water.max()
        thres = threshold_otsu(water)
        print("Otsu threshold is {}".format(thres))
        water_nowater = (water > thres).astype('uint8')
    else:
        # binary mask straight from the label image (class indices above 1)
        water_nowater = (mask > 1).astype('uint8')


    # Overlay figure: semi-transparent colour label (left) and water-mask contour (right)
    plt.clf()
    plt.subplot(121)
    plt.imshow(input_img[:,:,-1],cmap='gray')
    plt.imshow(color_label, alpha=0.4)
    plt.axis("off")
    plt.margins(x=0, y=0)

    plt.subplot(122)
    plt.imshow(input_img[:,:,-1],cmap='gray')
    plt.contour(water_nowater, levels=[0], colors='r')
    plt.axis("off")
    plt.margins(x=0, y=0)

    plt.savefig("overlay_download_me.png", dpi=300, bbox_inches="tight")    

    return color_label, plt, "greyscale_download_me.png", "color_download_me.png", "overlay_download_me.png"



# Build the Gradio interface
with open("article.html", "r", encoding='utf-8') as f:
    article = f.read()

title = "Segment satellite imagery"
description = "This simple model demonstration segments 15-m Landsat-7/8 or 10-m Sentinel-2 RGB (visible spectrum) imagery into the following classes: 1. water (unbroken water); 2. whitewater (surf, active wave breaking); 3. sediment (natural deposits of sand, gravel, mud, etc.); and 4. other (development, bare terrain, vegetated terrain, etc.). Please note that, ordinarily, ensemble models are used in predictive mode; here we use just one model, i.e. without ensembling. Upload 3-band imagery in JPEG format; the label imagery can be downloaded one image at a time."

examples = [[l] for l in glob('examples/*.jpg')]

inp = gr.Image()
out1 = gr.Image(type='numpy')
out2 = gr.Plot(type='matplotlib')
out3 = gr.File()
out4 = gr.File()
out5 = gr.File()
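# The three File outputs serve the saved greyscale, colour, and overlay PNGs for download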

# Checkboxes toggling test-time augmentation and Otsu thresholding
inp2 = gr.Checkbox(value=False, label="Use TTA")
inp3 = gr.Checkbox(value=False, label="Use Otsu")

Segapp = gr.Interface(segment, [inp, inp2, inp3],
                    [out1, out2, out3, out4, out5],
                    title=title, description=description, examples=examples, article=article,
                    theme="grass")
                    
Segapp.launch(enable_queue=True)