# dbuscombe's picture
# v1
# e62aad4
import gradio as gr
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
from glob import glob
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.io import imsave
from skimage.filters import threshold_otsu
from doodleverse_utils.prediction_imports import *
from doodleverse_utils.imports import *
# Load the trained Keras segmentation model from the SavedModel directory.
# compile=True compiles the model on load with its saved training config, so no
# separate model.compile() call is required before running predict().
filepath = './saved_model'
model = tf.keras.models.load_model(filepath, compile=True)
# Hex colours used to render each predicted class (class index -> colour).
# Only the first `num_classes` entries are passed to the colouriser.
_CLASS_LABEL_COLORMAP = [
    "#3366CC",
    "#DC3912",
    "#FF9900",
    "#109618",
    "#990099",
    "#0099C6",
    "#DD4477",
    "#66AA00",
    "#B82E2E",
    "#316395",
]


def _predict_softmax(img, use_tta):
    """Return the model's per-pixel class scores for a single-image batch.

    `img` has a leading batch axis of size 1 (it comes from np.expand_dims),
    so its spatial axes are 1 (rows) and 2 (cols). With `use_tta`, the result
    is the mean of the predictions for the image and its vertical, horizontal
    and combined flips, each prediction flipped back before averaging.
    """
    est_label = model.predict(img)
    if not use_tta:
        return est_label
    # Flip the spatial axes explicitly: np.flipud/np.fliplr act on axes 0/1,
    # which here are the batch axis and the row axis, not rows and columns.
    flip_axes = [(1,), (2,), (1, 2)]
    for axes in flip_axes:
        flipped = np.ascontiguousarray(np.flip(img, axis=axes))
        est_label = est_label + np.flip(model.predict(flipped, batch_size=1), axis=axes)
    # Soft voting: average the softmax scores of all views.
    return est_label / (len(flip_axes) + 1)


def _water_mask(pred, mask, use_otsu):
    """Return a uint8 0/1 image separating water-like from other pixels.

    With `use_otsu`, the summed scores of classes 0 and 1 (presumably water and
    whitewater, per the app description) are scaled to a maximum of 1 and
    thresholded with Otsu's method; pixels above the threshold are 1.
    Otherwise the argmax `mask` is used and classes above 1 are 1. The two
    branches use opposite polarity, but only the 0/1 boundary is drawn.
    """
    if not use_otsu:
        return (mask > 1).astype('uint8')
    water = pred[:, :, 0] + pred[:, :, 1]
    peak = water.max()
    if peak > 0:  # avoid 0/0 -> NaN when no pixel has any water score
        water = water / peak
    thres = threshold_otsu(water)
    print("Otsu threshold is {}".format(thres))
    return (water > thres).astype('uint8')


def _save_overlay(input_img, color_label, water_nowater, path):
    """Draw the colour-label overlay and the water boundary side by side; save to `path`."""
    plt.clf()
    plt.subplot(121)
    # The last channel of the input is used as the greyscale backdrop.
    plt.imshow(input_img[:, :, -1], cmap='gray')
    plt.imshow(color_label, alpha=0.4)
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.subplot(122)
    plt.imshow(input_img[:, :, -1], cmap='gray')
    # The mask is binary 0/1, so level 0.5 traces the boundary between the two
    # regions (level 0 sits at the data minimum and triggers a matplotlib warning).
    plt.contour(water_nowater, levels=[0.5], colors='r')
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.savefig(path, dpi=300, bbox_inches="tight")


def segment(input_img, use_tta, use_otsu, dims=(512, 512)):
    """Segment an image with the loaded model and write the downloadable outputs.

    Parameters
    ----------
    input_img : numpy.ndarray
        Image from the Gradio input; must be 3-D, indexed (rows, cols, channels).
    use_tta : bool
        Average predictions over flipped copies of the image (test-time augmentation).
    use_otsu : bool
        Derive the water boundary from an Otsu threshold on the water scores
        instead of from the argmax labels.
    dims : tuple of int
        Size the image is resized to before prediction.

    Returns
    -------
    tuple
        (colour label image, the ``matplotlib.pyplot`` module holding the overlay
        figure, path of the greyscale label PNG, path of the colour label PNG,
        path of the overlay PNG). All three PNGs are written to the working directory.

    Raises
    ------
    ValueError
        If `input_img` is not 3-dimensional.
    """
    print("Use Otsu threshold" if use_otsu else "No Otsu threshold")
    print("Use TTA" if use_tta else "Do not use TTA")

    rows, cols, _ = input_img.shape
    print("Original dimensions {}x{}".format(rows, cols))
    print("New dimensions {}x{}".format(dims[0], dims[1]))

    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)

    pred = np.squeeze(_predict_softmax(img, use_tta), axis=0)
    # Bring the class scores back to the original image size before labelling.
    pred = resize(pred, (rows, cols), preserve_range=True, clip=True)
    mask = np.argmax(pred, -1)
    imsave("greyscale_download_me.png", mask.astype('uint8'))

    num_classes = pred.shape[-1]
    color_label = label_to_colors(
        mask,
        input_img[:, :, 0] == 0,  # presumably a no-data mask (zero first channel) — confirm in doodleverse_utils
        alpha=128,
        colormap=_CLASS_LABEL_COLORMAP[:num_classes],
        color_class_offset=0,
        do_alpha=False,
    )
    imsave("color_download_me.png", color_label)

    water_nowater = _water_mask(pred, mask, use_otsu)
    _save_overlay(input_img, color_label, water_nowater, "overlay_download_me.png")
    return (
        color_label,
        plt,
        "greyscale_download_me.png",
        "color_download_me.png",
        "overlay_download_me.png",
    )
# Long-form HTML text passed to the interface as its `article`.
with open("article.html", "r", encoding='utf-8') as f:
    article = f.read()

title = "Segment Satellite imagery"
description = (
    "This simple model demonstration segments 15-m Landsat-7/8 or 10-m Sentinel-2 "
    "RGB (visible spectrum) imagery into the following classes: 1. water (unbroken "
    "water); 2. whitewater (surf, active wave breaking); 3. sediment (natural deposits "
    "of sand, gravel, mud, etc), and 4. other (development, bare terrain, vegetated "
    "terrain, etc). Please note that, ordinarily, ensemble models are used in "
    "predictive mode. Here, we are using just one model, i.e. without ensembling. "
    "Allows upload of 3-band imagery in jpg format and download of label imagery "
    "only one at a time. "
)
# One example row per bundled image; each row supplies only the image input.
# glob() returns files in arbitrary filesystem order, so sort for a stable gallery.
examples = [[path] for path in sorted(glob('examples/*.jpg'))]

inp = gr.Image()
out1 = gr.Image(type='numpy')
out2 = gr.Plot(type='matplotlib')
out3 = gr.File()
out4 = gr.File()
out5 = gr.File()
# gr.Checkbox(value=...) replaces the deprecated gr.inputs.Checkbox(default=...).
inp2 = gr.Checkbox(value=False, label="Use TTA")
inp3 = gr.Checkbox(value=False, label="Use Otsu")

Segapp = gr.Interface(segment, [inp, inp2, inp3],
                      [out1, out2, out3, out4, out5],
                      title=title, description=description, examples=examples,
                      article=article, theme="grass")
Segapp.launch(enable_queue=True)