"""Gradio demo: semantic segmentation with a TF SegFormer checkpoint.

The app blends a colour-coded class mask into the uploaded image and shows a
legend listing the classes that appear in the prediction.
"""
import gradio as gr
from matplotlib import gridspec
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
import requests

# Single source of truth for the checkpoint so the feature extractor and the
# model can never drift apart.
MODEL_CHECKPOINT = "nvidia/segformer-b5-finetuned-ade-640-640"

feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = TFSegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)


def ade_palette():
    """ADE20K palette that maps each class to RGB values.

    Returns:
        A list of 150 ``[R, G, B]`` triples; entry ``i`` is the colour used
        for class id ``i``.
    """
    return [
        [215, 252, 54], [219, 99, 20], [30, 125, 246], [21, 211, 22], [117, 165, 201],
        [122, 2, 6], [52, 144, 140], [136, 36, 114], [208, 249, 44], [210, 245, 157],
        [48, 222, 84], [175, 182, 112], [117, 9, 240], [153, 38, 30], [75, 225, 231],
        [232, 170, 70], [154, 35, 115], [45, 61, 35], [73, 144, 2], [54, 80, 136],
        [143, 200, 212], [75, 104, 98], [17, 211, 27], [205, 195, 241], [234, 251, 104],
        [33, 174, 95], [160, 174, 99], [141, 26, 157], [84, 247, 88], [19, 248, 198],
        [4, 217, 155], [204, 163, 16], [148, 209, 143], [211, 97, 65], [19, 4, 131],
        [40, 196, 45], [39, 64, 20], [166, 107, 50], [108, 103, 78], [188, 11, 213],
        [24, 156, 152], [230, 162, 223], [30, 126, 220], [74, 10, 238], [186, 128, 227],
        [83, 188, 220], [9, 132, 231], [96, 99, 79], [196, 139, 187], [117, 122, 171],
        [0, 156, 220], [243, 249, 189], [243, 245, 211], [103, 146, 83], [237, 144, 197],
        [35, 151, 20], [15, 61, 139], [78, 223, 132], [120, 49, 9], [67, 160, 234],
        [183, 244, 210], [245, 161, 139], [57, 70, 189], [105, 150, 31], [219, 85, 49],
        [206, 81, 97], [30, 171, 92], [251, 42, 67], [121, 183, 220], [221, 33, 43],
        [8, 96, 100], [76, 149, 53], [29, 201, 129], [7, 213, 227], [143, 93, 153],
        [205, 35, 110], [37, 94, 142], [131, 157, 110], [215, 166, 147], [164, 94, 252],
        [179, 108, 233], [35, 157, 209], [145, 252, 241], [155, 60, 40], [70, 25, 44],
        [53, 83, 133], [150, 42, 191], [142, 245, 58], [150, 198, 69], [0, 139, 86],
        [123, 212, 143], [210, 166, 191], [148, 194, 130], [35, 213, 154], [203, 139, 93],
        [59, 86, 45], [9, 50, 169], [207, 118, 246], [200, 82, 65], [37, 75, 120],
        [237, 99, 63], [168, 145, 190], [225, 48, 16], [17, 184, 115], [224, 124, 15],
        [148, 167, 47], [162, 25, 116], [154, 90, 36], [185, 247, 43], [183, 138, 202],
        [64, 96, 117], [187, 140, 140], [121, 116, 188], [252, 251, 162], [85, 50, 40],
        [209, 241, 228], [30, 41, 95], [246, 217, 64], [151, 149, 197], [117, 42, 205],
        [26, 248, 30], [28, 224, 232], [228, 89, 96], [198, 44, 113], [220, 68, 218],
        [59, 85, 210], [24, 230, 191], [145, 192, 181], [132, 189, 92], [47, 29, 128],
        [11, 245, 204], [182, 79, 207], [42, 64, 187], [72, 4, 37], [105, 67, 133],
        [86, 27, 200], [243, 211, 40], [150, 136, 40], [3, 192, 172], [34, 96, 149],
        [32, 108, 56], [128, 10, 137], [94, 211, 108], [78, 250, 243], [6, 74, 205],
        [6, 7, 38], [161, 26, 40], [145, 254, 27], [119, 145, 127], [13, 82, 153],
    ]


def _load_labels(path="labels.txt"):
    """Return the class names in *path*, one per line, in file order.

    Line ``i`` is used as the legend text for class id ``i`` in ``draw_plot``.
    """
    with open(path, "r", encoding="utf-8") as fp:
        # Strip only the newline. Slicing with ``line[:-1]`` would also chop a
        # real character off the final label when the file lacks a trailing
        # newline.
        return [line.rstrip("\n") for line in fp]


labels_list = _load_labels(r"labels.txt")

colormap = np.asarray(ade_palette())


def label_to_color_image(label):
    """Map a 2-D array of class ids to an RGB image via ``colormap``.

    Args:
        label: 2-D integer array of class ids.

    Returns:
        Array of shape ``label.shape + (3,)`` with the palette colour of each
        id (dtype follows ``colormap``).

    Raises:
        ValueError: if ``label`` is not 2-D or holds an id >= ``len(colormap)``.
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    if np.max(label) >= len(colormap):
        raise ValueError("label value too large.")
    return colormap[label]


def draw_plot(pred_img, seg):
    """Draw the overlay (left) and a legend of the predicted classes (right).

    Args:
        pred_img: image array to display (the blended overlay from ``sepia``).
        seg: 2-D map of predicted class ids. Either an object with a
            ``.numpy()`` method (e.g. a TF eager tensor) or anything
            ``np.asarray`` accepts.

    Returns:
        A ``matplotlib.figure.Figure``.
    """
    # Build the figure with the object-oriented API instead of pyplot. pyplot
    # keeps every figure it creates alive until plt.close(), so a long-running
    # server would accumulate one figure per request.
    fig = Figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1], figure=fig)

    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    label_names = np.asarray(labels_list)
    # Column of all class ids -> column of their colours, used as the legend strip.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    seg_array = seg.numpy() if hasattr(seg, "numpy") else np.asarray(seg)
    unique_labels = np.unique(seg_array.astype("uint8"))

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(
        full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
    )
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig


def sepia(input_img):
    """Segment *input_img* and return a figure with the overlay and legend.

    Args:
        input_img: image array as delivered by the ``gr.Image`` input.

    Returns:
        The figure built by ``draw_plot``.
    """
    # NOTE(review): the name looks like a leftover from the Gradio "sepia"
    # template; it is kept because ``demo`` refers to it.
    input_img = Image.fromarray(input_img)
    inputs = feature_extractor(images=input_img, return_tensors="tf")
    outputs = model(**inputs)

    # Treat axis 1 as the class axis and move it last, so tf.image.resize
    # scales the spatial axes and argmax(axis=-1) picks a class per pixel.
    logits = tf.transpose(outputs.logits, [0, 2, 3, 1])
    # PIL's ``size`` is (width, height); tf.image.resize expects (height, width).
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    # One vectorised palette gather instead of a boolean mask per class.
    color_seg = label_to_color_image(seg.numpy()).astype(np.uint8)  # height, width, 3

    # Show image + mask: a 50/50 blend of the input pixels and class colours.
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)


demo = gr.Interface(
    fn=sepia,
    inputs=gr.Image(),
    outputs=["plot"],
    examples=["image-1.jpg", "image-2.jpg", "image-3.jpg", "image-4.jpeg", "image-5.jpg"],
    allow_flagging="never",
)

demo.launch()