# Gradio demo: segment RGB satellite imagery with a single saved Keras model
# and offer the greyscale label, colour label and overlay plot for download.
import gradio as gr
import numpy as np
import os

# Hide all GPUs so TensorFlow runs on CPU; this must be set before
# TensorFlow is imported to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from glob import glob

import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.io import imsave
from skimage.filters import threshold_otsu

# Star imports are kept in their original position (after the explicit
# imports) because they supply `standardize` and `label_to_colors` and may
# intentionally shadow earlier names.
from doodleverse_utils.prediction_imports import *
from doodleverse_utils.imports import *

# load model
filepath = './saved_model'
model = tf.keras.models.load_model(filepath, compile = True)

# One colour per class; sliced to the number of classes the model predicts.
CLASS_LABEL_COLORMAP = [
    "#3366CC",
    "#DC3912",
    "#FF9900",
    "#109618",
    "#990099",
    "#0099C6",
    "#DD4477",
    "#66AA00",
    "#B82E2E",
    "#316395",
]


# segmentation
def segment(input_img, use_tta, use_otsu, dims=(512, 512)):
    """Segment an image with the loaded model and write download files.

    Args:
        input_img: image array; its shape is unpacked into exactly three
            values, so it must be 3-D (presumably rows x cols x bands as
            delivered by ``gr.Image`` -- confirm).
        use_tta: if true, average the softmax scores of the original image
            and its vertical, horizontal and vertical+horizontal flips
            (test-time augmentation, soft voting).
        use_otsu: if true, build the water/no-water map by Otsu-thresholding
            the summed scores of channels 0 and 1; otherwise derive it from
            the argmax label (classes 0 and 1 count as water).
        dims: (rows, cols) the image is resized to before prediction
            (presumably the model's expected input size -- confirm).

    Returns:
        Tuple of (colour label image, ``matplotlib.pyplot`` module holding the
        overlay figure, and the paths of the greyscale label, colour label
        and overlay PNG files).

    NOTE(review): outputs are written to fixed file names in the working
    directory, so concurrent requests can overwrite each other's files.
    """
    if use_otsu:
        print("Use Otsu threshold")
    else:
        print("No Otsu threshold")
    if use_tta:
        print("Use TTA")
    else:
        print("Do not use TTA")

    # numpy image shape is (rows, cols, bands); the 3-way unpack also
    # rejects inputs that are not 3-D.
    rows_orig, cols_orig, _channels = input_img.shape
    print("Original dimensions {}x{}".format(rows_orig, cols_orig))
    print("New dimensions {}x{}".format(dims[0], dims[1]))

    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)  # add batch axis -> (1, rows, cols, bands)

    est_label = model.predict(img)

    if use_tta:
        # Test-time augmentation. `img` is a 4-D batch, so the spatial axes
        # are 1 (rows) and 2 (cols). np.flipud/np.fliplr would act on axes
        # 0 (the size-1 batch axis, a no-op) and 1 instead. Each prediction
        # is flipped back with the same axes so the scores line up again.
        est_label2 = np.flip(model.predict(np.flip(img, axis=1), batch_size=1), axis=1)
        est_label3 = np.flip(model.predict(np.flip(img, axis=2), batch_size=1), axis=2)
        est_label4 = np.flip(
            model.predict(np.flip(img, axis=(1, 2)), batch_size=1), axis=(1, 2)
        )
        # soft voting - average the softmax scores to get the TTA estimate
        est_label = (est_label + est_label2 + est_label3 + est_label4) / 4

    pred = np.squeeze(est_label, axis=0)
    pred = resize(pred, (rows_orig, cols_orig), preserve_range=True, clip=True)
    mask = np.argmax(pred, -1)
    imsave("greyscale_download_me.png", mask.astype('uint8'))

    # One colour per predicted class (last axis of the model output).
    n_classes = pred.shape[-1]
    color_label = label_to_colors(
        mask,
        input_img[:, :, 0] == 0,
        alpha=128,
        colormap=CLASS_LABEL_COLORMAP[:n_classes],
        color_class_offset=0,
        do_alpha=False,
    )
    imsave("color_download_me.png", color_label)

    # Binary map with 1 = water in BOTH branches. Per the app description,
    # channels/classes 0 and 1 are water and whitewater.
    if use_otsu:
        water = pred[:, :, 0] + pred[:, :, 1]  # new array; `pred` is untouched
        water /= water.max()
        thres = threshold_otsu(water)
        print("Otsu threshold is {}".format(thres))
        water_nowater = (water > thres).astype('uint8')
    else:
        water_nowater = (mask <= 1).astype('uint8')

    # overlay plot
    plt.clf()
    plt.subplot(121)
    plt.imshow(input_img[:, :, -1], cmap='gray')
    plt.imshow(color_label, alpha=0.4)
    plt.axis("off")
    plt.margins(x=0, y=0)

    plt.subplot(122)
    plt.imshow(input_img[:, :, -1], cmap='gray')
    # 0.5 lies strictly between the two values of the binary map, so the
    # contour traces the water/no-water boundary between pixels.
    plt.contour(water_nowater, levels=[0.5], colors='r')
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.savefig("overlay_download_me.png", dpi=300, bbox_inches="tight")

    return (
        color_label,
        plt,
        "greyscale_download_me.png",
        "color_download_me.png",
        "overlay_download_me.png",
    )


with open("article.html", "r", encoding='utf-8') as f:
    article = f.read()

title = "Segment Satellite imagery"
description = (
    "This simple model demonstration segments 15-m Landsat-7/8 or 10-m "
    "Sentinel-2 RGB (visible spectrum) imagery into the following classes: "
    "1. water (unbroken water); 2. whitewater (surf, active wave breaking); "
    "3. sediment (natural deposits of sand, gravel, mud, etc), and 4. other "
    "(development, bare terrain, vegetated terrain, etc). Please note that, "
    "ordinarily, ensemble models are used in predictive mode. Here, we are "
    "using just one model, i.e. without ensembling. Allows upload of 3-band "
    "imagery in jpg format and download of label imagery only one at a time. "
)

# Sorted so the example gallery order is stable across runs/filesystems.
examples = [[path] for path in sorted(glob('examples/*.jpg'))]

inp = gr.Image()
out1 = gr.Image(type='numpy')
out2 = gr.Plot(type='matplotlib')
out3 = gr.File()
out4 = gr.File()
out5 = gr.File()

# gr.Checkbox(value=...) replaces the deprecated gr.inputs.Checkbox(default=...).
inp2 = gr.Checkbox(value=False, label="Use TTA")
inp3 = gr.Checkbox(value=False, label="Use Otsu")

Segapp = gr.Interface(
    segment,
    [inp, inp2, inp3],
    [out1, out2, out3, out4, out5],
    title=title,
    description=description,
    examples=examples,
    article=article,
    theme="grass",
)
Segapp.launch(enable_queue=True)