Spaces:
Sleeping
Sleeping
File size: 4,538 Bytes
e62aad4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import gradio as gr
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
from glob import glob
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.io import imsave
from skimage.filters import threshold_otsu
from doodleverse_utils.prediction_imports import *
from doodleverse_utils.imports import *
# Load the segmentation model once at import time; `segment` reads this
# module-level `model` for every request.
filepath = './saved_model'
# compile=True makes load_model compile the model after loading, so no
# separate compile step is needed (the previous bare `model.compile`
# attribute access was never called and did nothing).
model = tf.keras.models.load_model(filepath, compile=True)
# segmentation
def _predict_softmax(img, use_tta):
    """Return the model's softmax scores for a batch holding one image.

    Parameters
    ----------
    img : numpy.ndarray
        Batch of one image built with ``np.expand_dims(..., axis=0)``, so
        axis 0 is the batch axis and axes 1 and 2 are the spatial axes
        (rows, columns).
    use_tta : bool
        If True, apply test-time augmentation. The scores for the original
        image, a vertical flip, a horizontal flip and a combined flip are
        averaged (soft voting).

    Returns
    -------
    numpy.ndarray
        Scores laid out like ``model.predict`` output, i.e. presumably
        (1, rows, cols, classes); each augmented view is flipped back
        before averaging, so pixels stay aligned with ``img``.
    """
    est_label = model.predict(img)
    if not use_tta:
        return est_label
    # np.flipud / np.fliplr act on axes 0 / 1. On this 4-D batch those are
    # the batch axis (a no-op, size 1) and the row axis. The spatial axes
    # must therefore be named explicitly to get a true vertical (1),
    # horizontal (2) and combined (1, 2) flip.
    flip_axes = (1, 2, (1, 2))
    for axes in flip_axes:
        flipped_pred = model.predict(np.flip(img, axis=axes), batch_size=1)
        # Undo the flip so scores line up with the un-augmented image.
        est_label = est_label + np.flip(flipped_pred, axis=axes)
    # soft voting - average the softmax scores of all views
    return est_label / (len(flip_axes) + 1)


def segment(input_img, use_tta, use_otsu, dims=(512, 512)):
    """Segment an image and produce the app's display and download outputs.

    The image is standardized and resized to ``dims`` before prediction.
    The per-class scores are then resized back to the original size, and
    the label image is the per-pixel argmax.

    Parameters
    ----------
    input_img : numpy.ndarray
        Image from the ``gr.Image`` input, unpacked as (rows, cols,
        channels); presumably an RGB uint8 array — confirm against the
        Gradio version in use.
    use_tta : bool
        Average predictions over flipped copies (see ``_predict_softmax``).
    use_otsu : bool
        If True, the water/no-water line in the overlay comes from an Otsu
        threshold on the summed class-0 + class-1 scores. Otherwise it is
        taken directly from the argmax label image.
    dims : tuple of int, optional
        Size the image is resized to before being fed to the model.

    Returns
    -------
    tuple
        ``(color_label, plt, "greyscale_download_me.png",
        "color_download_me.png", "overlay_download_me.png")``. The second
        item is the pyplot module, whose current figure holds the overlay.

    Side effects: writes the three PNG files above to the working directory
    and redraws the current matplotlib figure.
    """
    if use_otsu:
        print("Use Otsu threshold")
    else:
        print("No Otsu threshold")
    if use_tta:
        print("Use TTA")
    else:
        print("Do not use TTA")

    # Unpacking also enforces a 3-D (rows, cols, channels) input.
    worig, horig, _ = input_img.shape
    w, h = dims[0], dims[1]
    print("Original dimensions {}x{}".format(worig, horig))
    print("New dimensions {}x{}".format(w, h))

    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)

    est_label = _predict_softmax(img, use_tta)

    pred = np.squeeze(est_label, axis=0)
    pred = resize(pred, (worig, horig), preserve_range=True, clip=True)
    mask = np.argmax(pred, -1)
    imsave("greyscale_download_me.png", mask.astype('uint8'))

    class_label_colormap = [
        "#3366CC",
        "#DC3912",
        "#FF9900",
        "#109618",
        "#990099",
        "#0099C6",
        "#DD4477",
        "#66AA00",
        "#B82E2E",
        "#316395",
    ]
    # One colour per model output class (4 for this model: presumably
    # water, whitewater, sediment, other per the app description).
    num_classes = pred.shape[-1]
    class_label_colormap = class_label_colormap[:num_classes]
    color_label = label_to_colors(
        mask,
        # presumably a no-data mask: pixels whose first band is 0 — confirm
        # against doodleverse_utils.label_to_colors
        input_img[:, :, 0] == 0,
        alpha=128,
        colormap=class_label_colormap,
        color_class_offset=0,
        do_alpha=False,
    )
    imsave("color_download_me.png", color_label)

    if use_otsu:
        # Combined score of classes 0 and 1 (presumably water + whitewater).
        water = pred[:, :, 0] + pred[:, :, 1]
        peak = water.max()
        if peak > 0:  # avoid 0/0 -> NaN when the score is zero everywhere
            water /= peak
        thres = threshold_otsu(water)
        print("Otsu threshold is {}".format(thres))
        water_nowater = (water > thres).astype('uint8')
    else:
        # NOTE(review): polarity is the opposite of the Otsu branch (here 1 =
        # classes 2-3). Only the 0/1 boundary is drawn below, so the contour
        # location is the same either way.
        water_nowater = (mask > 1).astype('uint8')

    # overlay plot: class colours on the left, water boundary on the right
    plt.clf()
    plt.subplot(121)
    plt.imshow(input_img[:, :, -1], cmap='gray')
    plt.imshow(color_label, alpha=0.4)
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.subplot(122)
    plt.imshow(input_img[:, :, -1], cmap='gray')
    plt.contour(water_nowater, levels=[0], colors='r')
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.savefig("overlay_download_me.png", dpi=300, bbox_inches="tight")
    return color_label, plt, "greyscale_download_me.png", "color_download_me.png", "overlay_download_me.png"
# Build and launch the Gradio interface around `segment`.
with open("article.html", "r", encoding='utf-8') as f:
    article = f.read()

title = "Segment Satellite imagery"
description = (
    "This simple model demonstration segments 15-m Landsat-7/8 or 10-m "
    "Sentinel-2 RGB (visible spectrum) imagery into the following classes: "
    "1. water (unbroken water); 2. whitewater (surf, active wave breaking); "
    "3. sediment (natural deposits of sand, gravel, mud, etc), and "
    "4. other (development, bare terrain, vegetated terrain, etc). "
    "Please note that, ordinarily, ensemble models are used in predictive "
    "mode. Here, we are using just one model, i.e. without ensembling. "
    "Allows upload of 3-band imagery in jpg format and download of label "
    "imagery only one at a time. "
)
# Sorted so the example gallery has a stable order (glob order is arbitrary).
examples = [[path] for path in sorted(glob('examples/*.jpg'))]

inp = gr.Image()
out1 = gr.Image(type='numpy')
out2 = gr.Plot(type='matplotlib')
out3 = gr.File()
out4 = gr.File()
out5 = gr.File()
# Use the same component API as the other inputs/outputs instead of the
# legacy gr.inputs namespace (which took `default=` rather than `value=`).
inp2 = gr.Checkbox(value=False, label="Use TTA")
inp3 = gr.Checkbox(value=False, label="Use Otsu")

Segapp = gr.Interface(
    segment,
    [inp, inp2, inp3],
    [out1, out2, out3, out4, out5],
    title=title,
    description=description,
    examples=examples,
    article=article,
    theme="grass",
)
Segapp.launch(enable_queue=True)
|