import os

# Must be set before TensorFlow is imported so that inference runs on the CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from glob import glob

import gradio as gr
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.io import imsave
from skimage.filters import threshold_otsu
# These star imports provide `standardize` and `label_to_colors` used below.
from doodleverse_utils.prediction_imports import *
from doodleverse_utils.imports import *

# Imported after the star imports so that they cannot shadow it.
from matplotlib.figure import Figure

GREYSCALE_PATH = "greyscale_download_me.png"
COLOR_PATH = "color_download_me.png"
OVERLAY_PATH = "overlay_download_me.png"

# One colour per class; sliced to the number of classes the model outputs.
CLASS_LABEL_COLORMAP = [
    "#3366CC",
    "#DC3912",
    "#FF9900",
    "#109618",
    "#990099",
    "#0099C6",
    "#DD4477",
    "#66AA00",
    "#B82E2E",
    "#316395",
]

# load model (compile=True already compiles it; the former bare `model.compile`
# attribute access was a no-op and has been removed)
filepath = './saved_model'
model = tf.keras.models.load_model(filepath, compile = True)


def _predict_scores(batch, use_tta):
    """Return the model's per-class scores for `batch`, optionally with flip TTA.

    `batch` is the (1, rows, cols, bands) array built by `segment`. The model
    output is assumed to be spatially aligned with the input, i.e.
    (1, rows, cols, classes) -- TODO confirm for other models.

    With `use_tta`, the scores for the original plus the vertically,
    horizontally and doubly flipped inputs are flipped back and averaged
    (soft voting).
    """
    scores = model.predict(batch)
    if not use_tta:
        return scores
    # Axis 0 is the batch axis (size 1): np.flipud would flip *it*, which is a
    # no-op. Flip the spatial axes explicitly instead: 1 = rows, 2 = columns.
    for axes in ((1,), (2,), (1, 2)):
        flipped_scores = model.predict(np.flip(batch, axis=axes), batch_size=1)
        scores = scores + np.flip(flipped_scores, axis=axes)
    return scores / 4


def _overlay_figure(input_img, color_label, boundary_mask):
    """Build and save the two-panel overlay figure and return it.

    The left panel shows the colour label over the last band of `input_img`.
    The right panel draws the boundary of the 0/1 `boundary_mask` in red.
    A standalone `Figure` is used instead of pyplot's global figure, so calls
    share no state and no figures accumulate in pyplot's registry.
    """
    fig = Figure()
    background = input_img[:, :, -1]
    left, right = fig.subplots(1, 2)

    left.imshow(background, cmap='gray')
    left.imshow(color_label, alpha=0.4)

    right.imshow(background, cmap='gray')
    # 0.5 sits strictly between the two mask values, giving a clean boundary.
    right.contour(boundary_mask, levels=[0.5], colors='r')

    for ax in (left, right):
        ax.axis("off")
        ax.margins(x=0, y=0)

    fig.savefig(OVERLAY_PATH, dpi=300, bbox_inches="tight")
    return fig


#segmentation
def segment(input_img, use_tta, use_otsu, dims=(512, 512)):
    """Segment an uploaded image and write the label products to disk.

    Args:
        input_img: image array from `gr.Image()`, unpacked below as
            (rows, cols, channels).
        use_tta: average predictions over flipped copies (test-time augmentation).
        use_otsu: draw the overlay contour from an Otsu threshold on the
            last class's score map instead of from the argmax label.
        dims: (rows, cols) size the image is resized to before prediction.

    Returns:
        (color_label, overlay_figure, greyscale_path, color_path, overlay_path).
        The three PNGs are written to the current working directory.

    Raises:
        ValueError: if no image was supplied.
    """
    if input_img is None:
        raise ValueError("No image provided; please upload an image.")

    print("Use Otsu threshold" if use_otsu else "No Otsu threshold")
    print("Use TTA" if use_tta else "Do not use TTA")

    # numpy image shape is (rows, cols, channels)
    nrows, ncols, _ = input_img.shape
    print("Original dimensions {}x{}".format(nrows, ncols))
    print("New dimensions {}x{}".format(dims[0], dims[1]))

    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    scores = _predict_scores(np.expand_dims(img, axis=0), use_tta)

    # back to the original resolution; last axis holds one score per class
    pred = resize(np.squeeze(scores, axis=0), (nrows, ncols), preserve_range=True, clip=True)
    mask = np.argmax(pred, -1)
    imsave(GREYSCALE_PATH, mask.astype('uint8'))

    n_classes = pred.shape[-1]
    color_label = label_to_colors(
        mask,
        # presumably marks no-data pixels (first band == 0) -- confirm in doodleverse_utils
        input_img[:, :, 0] == 0,
        alpha=128,
        colormap=CLASS_LABEL_COLORMAP[:n_classes],
        color_class_offset=0,
        do_alpha=False,
    )
    imsave(COLOR_PATH, color_label)

    if use_otsu:
        # Threshold the continuous score of the last class. Thresholding the
        # integer argmax mask (as before) always reduced to `mask > 0`.
        last_class_scores = pred[..., -1]
        thres = threshold_otsu(last_class_scores)
        print("Otsu threshold is {}".format(thres))
        water_nowater = (last_class_scores > thres).astype('uint8')
    else:
        water_nowater = (mask >= 1).astype('uint8')

    fig = _overlay_figure(input_img, color_label, water_nowater)
    return color_label, fig, GREYSCALE_PATH, COLOR_PATH, OVERLAY_PATH


with open("article.html", "r", encoding='utf-8') as f:
    article = f.read()

title = "Segment Satellite imagery"
description = (
    "This simple model demonstration segments 15-m Landsat-7/8 or 10-m "
    "Sentinel-2 RGB (visible spectrum) imagery into the following classes: "
    "1. water and 2. other. Please note that, ordinarily, ensemble models are "
    "used in predictive mode. Here, we are using just one model, i.e. without "
    "ensembling. Allows upload of 3-band imagery in jpg format and download "
    "of label imagery only one at a time. "
)

# sorted so the example gallery order does not depend on the filesystem
examples = [[path] for path in sorted(glob('examples/*.jpg'))]

inp = gr.Image()
out1 = gr.Image(type='numpy')
out2 = gr.Plot(type='matplotlib')
out3 = gr.File()
out4 = gr.File()
out5 = gr.File()
inp2 = gr.inputs.Checkbox(default=False, label="Use TTA")
inp3 = gr.inputs.Checkbox(default=False, label="Use Otsu")

Segapp = gr.Interface(
    segment,
    [inp, inp2, inp3],
    [out1, out2, out3, out4, out5],
    title=title,
    description=description,
    examples=examples,
    article=article,
    theme="grass",
)
Segapp.launch(enable_queue=True)