import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib import gridspec
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Load the pre-trained model and feature extractor.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b0-finetuned-cityscapes-640-1280"
)
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-cityscapes-640-1280"
)


def my_palette():
    # One RGB color per class name in labels.txt, in label-id order.
    return [
        [131, 162, 255],
        [180, 189, 255],
        [255, 227, 187],
        [255, 210, 143],
        [248, 117, 170],
        [255, 223, 223],
        [255, 246, 246],
        [174, 222, 252],
        [150, 194, 145],
        [255, 219, 170],
        [244, 238, 238],
        [50, 38, 83],
        [128, 98, 214],
        [146, 136, 248],
        [255, 210, 215],
        [255, 152, 152],
        [162, 103, 138],
        [63, 29, 56],
    ]


# Read the class names, one per line; rstrip avoids chopping a real character
# when the file lacks a trailing newline.
labels_list = []
with open("labels.txt", "r") as fp:
    for line in fp:
        labels_list.append(line.rstrip("\n"))

colormap = np.asarray(my_palette())


def greet(input_img):
    inputs = feature_extractor(images=input_img, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits  # (batch, num_labels, height / 4, width / 4)

    # Detach from the PyTorch graph before handing the logits to TensorFlow,
    # then reorder to channels-last: (batch, height, width, num_labels).
    logits_tf = tf.transpose(logits.detach().numpy(), [0, 2, 3, 1])
    # Upsample the logits to the input resolution; PIL's .size is (w, h),
    # while tf.image.resize expects (h, w).
    logits_tf = tf.image.resize(logits_tf, input_img.size[::-1])
    seg = tf.math.argmax(logits_tf, axis=-1)[0]

    # Paint each predicted class with its palette color.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # (h, w, 3)
    for label, color in enumerate(colormap):
        color_seg[seg.numpy() == label, :] = color

    # Blend the color mask with the input image at 50% opacity.
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)


def draw_plot(pred_img, seg):
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the blended prediction.
    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis("off")

    # Right panel: a legend with one color swatch per class present in `seg`.
    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

    unique_labels = np.unique(seg.numpy().astype("uint8"))
    ax = plt.subplot(grid_spec[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig


def label_to_color_image(label):
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    if np.max(label) >= len(colormap):
        raise ValueError("label value too large")
    return colormap[label]


iface = gr.Interface(
    fn=greet,
    # Request a PIL image so input_img.size is (width, height); the default
    # numpy input has no such attribute.
    inputs=gr.Image(type="pil"),
    outputs=["plot"],
    examples=[
        "image (1).jpg",
        "image (2).jpg",
        "image (3).jpg",
        "image (4).jpg",
        "image (5).jpg",
    ],
    allow_flagging="never",
)

iface.launch(share=True)
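# ---------------------------------------------------------------------------
# Note: this script assumes a local labels.txt with one class name per line,
# in the same order as the palette returned by my_palette(). A minimal sketch
# for regenerating that file from the checkpoint's own id2label mapping (this
# snippet is an assumption, not part of the original app):
#
#     with open("labels.txt", "w") as fp:
#         for idx in sorted(model.config.id2label):
#             fp.write(model.config.id2label[idx] + "\n")
# ---------------------------------------------------------------------------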