"""Gradio demo: SegFormer (Cityscapes) semantic segmentation with a colour legend."""
import io
import tempfile
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import tensorflow as tf
import torch
from matplotlib import gridspec
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-640-1280"

# Both objects are used by ``greet``: the extractor turns a PIL image into model
# inputs, the (PyTorch) model produces per-class logits.
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_NAME)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)

urls = [
    "http://farm3.staticflickr.com/2523/3705549787_79049b1b6d_z.jpg",
    "http://farm8.staticflickr.com/7012/6476201279_52db36af64_z.jpg",
    "http://farm8.staticflickr.com/7180/6967423255_a3d65d5f6b_z.jpg",
    "http://farm4.staticflickr.com/3563/3470840644_3378804bea_z.jpg",
    "http://farm9.staticflickr.com/8388/8516454091_0ebdc1130a_z.jpg",
]


def _download_image(url, timeout=30):
    """Download ``url`` and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server does not answer within ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Decode from the fully-read body (not ``response.raw``) so requests handles
    # any Content-Encoding before PIL sees the bytes.
    return Image.open(io.BytesIO(response.content)).convert("RGB")


images = [_download_image(url) for url in urls]

# Gradio examples are one inner list per example (one entry per input component);
# image examples are given as file paths, the format every Gradio version accepts.
_example_dir = Path(tempfile.mkdtemp(prefix="segformer_examples_"))
example_paths = []
for idx, example_img in enumerate(images):
    example_path = _example_dir / f"example_{idx}.jpg"
    example_img.save(example_path)
    example_paths.append(str(example_path))


def my_palette():
    """Return one RGB colour per class id; entry ``i`` colours class ``i``."""
    return [
        [131, 162, 255],
        [180, 189, 255],
        [255, 227, 187],
        [255, 210, 143],
        [248, 117, 170],
        [255, 223, 223],
        [255, 246, 246],
        [174, 222, 252],
        [150, 194, 145],
        [255, 219, 170],
        [244, 238, 238],
        [50, 38, 83],
        [128, 98, 214],
        [146, 136, 248],
        [255, 210, 215],
        [255, 152, 152],
        [162, 103, 138],
        [63, 29, 56],
        # NOTE(review): added so class ids 0-18 all have a colour; the Cityscapes
        # checkpoint is expected to predict 19 classes -- confirm against labels.txt.
        [119, 11, 32],
    ]


with open("labels.txt", "r", encoding="utf-8") as fp:
    # rstrip rather than line[:-1]: the last line may lack a trailing newline.
    labels_list = [line.rstrip("\n") for line in fp]

colormap = np.asarray(my_palette())


def greet(input_img):
    """Segment ``input_img`` and return a figure with the overlay and a legend.

    Args:
        input_img: image from the Gradio ``Image`` component. With Gradio's default
            ``type="numpy"`` this is an array; a PIL image is accepted too.

    Returns:
        The matplotlib figure built by ``draw_plot``.
    """
    if isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img)
    # Guarantee 3 channels so the blend with the colour mask below lines up.
    input_img = input_img.convert("RGB")

    inputs = feature_extractor(images=input_img, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs)
    # shape (batch_size, num_labels, height/4, width/4); TF needs a NumPy array,
    # not a torch tensor.
    logits = outputs.logits.cpu().numpy()
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # PIL's ``size`` is (width, height); tf.image.resize wants (height, width).
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]
    seg_np = seg.numpy()

    color_seg = np.zeros((seg_np.shape[0], seg_np.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(colormap):
        color_seg[seg_np == label, :] = color

    # Show image + mask as a 50/50 blend.
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)


def draw_plot(pred_img, seg):
    """Plot the blended image next to a legend of the classes present in ``seg``.

    Args:
        pred_img: uint8 RGB image (image/mask blend) to display.
        seg: TensorFlow tensor of per-pixel class ids.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis("off")

    LABEL_NAMES = np.asarray(labels_list)
    # Column vector of every class id -> one colour swatch per row.
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

    unique_labels = np.unique(seg.numpy().astype("uint8"))
    ax = plt.subplot(grid_spec[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig


def label_to_color_image(label):
    """Map a 2-D array of class ids to RGB colours via ``colormap``.

    Raises:
        ValueError: if ``label`` is not 2-D or holds an id with no palette colour.
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")

    if np.max(label) >= len(colormap):
        raise ValueError("label value too large.")

    return colormap[label]


iface = gr.Interface(
    fn=greet,
    inputs=gr.Image(shape=(640, 1280)),
    outputs=["plot"],
    examples=[[path] for path in example_paths],
    allow_flagging="never",
)

iface.launch(share=True)