"""Gradio demo: SegFormer-B5 semantic segmentation with an ADE20K colour legend."""
import gradio as gr
import pandas as pd
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation

feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640"
)
model = TFSegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640"
)


def ade_palette():
    """ADE20K palette that maps each class to RGB values.

    Returns a fresh list of 150 ``[R, G, B]`` entries; entry ``i`` is the
    colour of class ``i`` (same order as ``labels_list``).
    """
    return [
        [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
        [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
        [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
        [102, 255, 0], [92, 0, 255],
    ]


# Class names indexed by class id; entry i pairs with ade_palette()[i].
labels_list = [
    'wall', 'building;edifice', 'sky', 'floor;flooring', 'tree', 'ceiling',
    'road;route', 'bed', 'windowpane;window', 'grass', 'cabinet',
    'sidewalk;pavement', 'person;individual;someone;somebody;mortal;soul',
    'earth;ground', 'door;double;door', 'table', 'mountain;mount',
    'plant;flora;plant;life', 'curtain;drape;drapery;mantle;pall', 'chair',
    'car;auto;automobile;machine;motorcar', 'water', 'painting;picture',
    'sofa;couch;lounge', 'shelf', 'house', 'sea', 'mirror',
    'rug;carpet;carpeting', 'field', 'armchair', 'seat', 'fence;fencing',
    'desk', 'rock;stone', 'wardrobe;closet;press', 'lamp',
    'bathtub;bathing;tub;bath;tub', 'railing;rail', 'cushion',
    'base;pedestal;stand', 'box', 'column;pillar', 'signboard;sign',
    'chest;of;drawers;chest;bureau;dresser', 'counter', 'sand', 'sink',
    'skyscraper', 'fireplace;hearth;open;fireplace', 'refrigerator;icebox',
    'grandstand;covered;stand', 'path', 'stairs;steps', 'runway',
    'case;display;case;showcase;vitrine',
    'pool;table;billiard;table;snooker;table', 'pillow',
    'screen;door;screen', 'stairway;staircase', 'river', 'bridge;span',
    'bookcase', 'blind;screen', 'coffee;table;cocktail;table',
    'toilet;can;commode;crapper;pot;potty;stool;throne', 'flower',
    'book', 'hill', 'bench', 'countertop',
    'stove;kitchen;stove;range;kitchen;range;cooking;stove',
    'palm;palm;tree', 'kitchen;island',
    'computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system',
    'swivel;chair', 'boat', 'bar', 'arcade;machine',
    'hovel;hut;hutch;shack;shanty',
    'bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle',
    'towel', 'light;light;source', 'truck;motortruck', 'tower',
    'chandelier;pendant;pendent', 'awning;sunshade;sunblind',
    'streetlight;street;lamp', 'booth;cubicle;stall;kiosk',
    'television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box',
    'airplane;aeroplane;plane', 'dirt;track',
    'apparel;wearing;apparel;dress;clothes', 'pole', 'land;ground;soil',
    'bannister;banister;balustrade;balusters;handrail',
    'escalator;moving;staircase;moving;stairway',
    'ottoman;pouf;pouffe;puff;hassock', 'bottle', 'buffet;counter;sideboard',
    'poster;posting;placard;notice;bill;card', 'stage', 'van', 'ship',
    'fountain',
    'conveyer;belt;conveyor;belt;conveyer;conveyor;transporter', 'canopy',
    'washer;automatic;washer;washing;machine', 'plaything;toy',
    'swimming;pool;swimming;bath;natatorium', 'stool', 'barrel;cask',
    'basket;handbasket', 'waterfall;falls', 'tent;collapsible;shelter',
    'bag', 'minibike;motorbike', 'cradle', 'oven', 'ball', 'food;solid;food',
    'step;stair', 'tank;storage;tank', 'trade;name;brand;name;brand;marque',
    'microwave;microwave;oven', 'pot;flowerpot',
    'animal;animate;being;beast;brute;creature;fauna',
    'bicycle;bike;wheel;cycle', 'lake',
    'dishwasher;dish;washer;dishwashing;machine',
    'screen;silver;screen;projection;screen', 'blanket;cover', 'sculpture',
    'hood;exhaust;hood', 'sconce', 'vase',
    'traffic;light;traffic;signal;stoplight', 'tray',
    'ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin',
    'fan', 'pier;wharf;wharfage;dock', 'crt;screen', 'plate',
    'monitor;monitoring;device',
    'bulletin;board;notice;board', 'shower', 'radiator',
    'glass;drinking;glass', 'clock', 'flag',
]


def label_to_color_image(label):
    """Map a 2-D array of class ids to their ADE20K palette colours.

    Args:
        label: A 2-D integer array of class ids (segmentation label).

    Returns:
        result: An integer array of shape ``label.shape + (3,)`` whose last
            axis holds the RGB colour from ``ade_palette()`` for each element
            of ``label``.

    Raises:
        ValueError: If ``label`` is not of rank 2 or one of its values is not
            a valid palette index (>= number of palette entries).
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")

    colormap = np.asarray(ade_palette())

    if np.max(label) >= len(colormap):
        raise ValueError("label value too large.")

    return colormap[label]


def draw_plot(pred_img, seg):
    """Draw the overlay image next to a legend of the classes present in ``seg``.

    Args:
        pred_img: Image shown in the large left panel (``sepia`` passes the
            blended image/mask overlay).
        seg: 2-D array-like of integer class ids. Anything ``np.asarray`` can
            convert is accepted (a NumPy array or an eager TensorFlow tensor).

    Returns:
        The matplotlib ``Figure`` holding both panels.
    """
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis('off')

    label_names = np.asarray(labels_list)
    # Column vector of every class id -> (num_classes, 1, 3) colour strip.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    # Only the classes that actually occur in the prediction go in the legend.
    unique_labels = np.unique(np.asarray(seg).astype("uint8"))
    ax = plt.subplot(grid_spec[1])
    plt.imshow(full_color_map[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), label_names[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig


def sepia(input_img):
    """Segment ``input_img`` with SegFormer and return an annotated figure.

    Despite its name, this does not apply a sepia filter: it runs the
    module-level SegFormer model, blends the per-pixel class colours 50/50 with
    the input, and returns the figure produced by ``draw_plot``.

    Args:
        input_img: Image as a NumPy array (as handed over by ``gr.Image``).

    Returns:
        A matplotlib ``Figure`` with the overlay and a class legend.
    """
    # Normalise to 3-channel RGB so the blend below always has matching shapes;
    # grayscale or RGBA uploads would otherwise fail to broadcast.
    input_img = Image.fromarray(input_img).convert("RGB")
    inputs = feature_extractor(images=input_img, return_tensors="tf")
    outputs = model(**inputs)

    # The transpose implies logits arrive channels-first (batch, classes, h, w);
    # tf.image.resize expects channels-last, so move the class axis to the end.
    logits = tf.transpose(outputs.logits, [0, 2, 3, 1])
    # PIL's ``size`` is (width, height); tf.image.resize wants (height, width).
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0].numpy()  # (height, width) class ids

    # One vectorised palette lookup instead of one masking pass per class.
    # The palette is RGB, like the input image and the legend in draw_plot,
    # so no channel reordering is applied (a BGR flip here made overlay
    # colours disagree with the legend).
    color_seg = label_to_color_image(seg).astype(np.uint8)

    # Show image + mask, blended 50/50.
    pred_img = (np.array(input_img) * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg)


demo = gr.Interface(
    sepia,
    gr.Image(shape=(200, 200)),
    outputs=['plot'],
    examples=["ADE_val_00000001.jpeg"],
)

demo.launch()