"""Gradio demo: semantic segmentation with SegFormer-B0 (Cityscapes).

Loads the pre-trained model, colours each predicted class, overlays the mask
on the input image, and renders it next to a legend of the detected classes.
"""
import gradio as gr
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from PIL import Image
import tensorflow as tf
import torch
import requests

# Load the pre-trained model and feature extractor.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b0-finetuned-cityscapes-640-1280"
)
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-cityscapes-640-1280"
)

# Resolution (height, width) to which logits are upsampled before argmax;
# matches the 640x1280 resolution named in the checkpoint id above.
SEG_SIZE = (640, 1280)


def my_palette():
    """Return the RGB colour for each class index (index i -> row i)."""
    return [
        [131, 162, 255],
        [180, 189, 255],
        [255, 227, 187],
        [255, 210, 143],
        [248, 117, 170],
        [255, 223, 223],
        [255, 246, 246],
        [174, 222, 252],
        [150, 194, 145],
        [255, 219, 170],
        [244, 238, 238],
        [50, 38, 83],
        [128, 98, 214],
        [146, 136, 248],
        [255, 210, 215],
        [255, 152, 152],
        [162, 103, 138],
        [63, 29, 56],
        [0, 0, 0],
    ]


# One class name per line; line i names class index i.
# rstrip (rather than line[:-1]) avoids chopping a real character off the
# last label when the file has no trailing newline.
with open(r"labels.txt", "r", encoding="utf-8") as fp:
    labels_list = [line.rstrip("\r\n") for line in fp]

colormap = np.asarray(my_palette())


def label_to_color_image(label):
    """Map a 2-D array of class indices to an (H, W, 3) array of RGB colours.

    Indices outside the palette are clipped to the nearest valid entry.

    Raises:
        ValueError: if ``label`` is not 2-D.
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    label = np.clip(label, 0, len(colormap) - 1)
    return colormap[label]


def greet(input_img):
    """Segment ``input_img`` and return a matplotlib figure with overlay + legend.

    ``input_img`` is presumably the (H, W, 3) uint8 array that Gradio's
    "image" input supplies — confirm if the input component changes.
    """
    inputs = feature_extractor(images=input_img, return_tensors="pt")
    # Inference only: skip building an autograd graph on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    # The transpose assumes channel-first (batch, classes, h, w) logits and
    # converts them to TF's channel-last layout.
    logits = outputs.logits.cpu().numpy()
    logits_tf = tf.transpose(logits, [0, 2, 3, 1])
    logits_tf = tf.image.resize(logits_tf, SEG_SIZE)
    seg = tf.math.argmax(logits_tf, axis=-1)[0].numpy()

    color_seg = label_to_color_image(seg)

    # Resize the colour mask to the input size with nearest-neighbour so that
    # class colours are not blended into off-palette colours at boundaries.
    img = np.asarray(input_img)
    color_seg_resized = tf.image.resize(
        color_seg, img.shape[:2], method="nearest"
    ).numpy()

    # 50/50 blend of the photo and the mask.
    pred_img = (img * 0.5 + color_seg_resized * 0.5).astype(np.uint8)

    return draw_plot(pred_img, seg)


def draw_plot(pred_img, seg):
    """Render ``pred_img`` beside a legend of the classes present in ``seg``.

    Returns the Figure. It is detached from pyplot's registry (``plt.close``)
    so per-request figures do not accumulate, but it can still be saved and
    rendered by the caller.
    """
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    ax_img = fig.add_subplot(grid_spec[0])
    ax_img.imshow(pred_img)
    ax_img.axis("off")

    label_names = np.asarray(labels_list)
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    # Keep only predicted classes that have a name in labels.txt.
    unique_labels = np.unique(seg.astype("uint8"))
    unique_labels = unique_labels[unique_labels < len(full_color_map)]

    ax = fig.add_subplot(grid_spec[1])
    if len(unique_labels) > 0:
        ax.imshow(
            full_color_map[unique_labels].astype(np.uint8),
            interpolation="nearest",
        )
        ax.yaxis.tick_right()
        ax.set_yticks(range(len(unique_labels)))
        ax.set_yticklabels(label_names[unique_labels])
    else:
        # No nameable classes: draw an empty legend cell.
        ax.imshow(np.zeros((1, 1, 3), dtype=np.uint8))
        ax.yaxis.tick_right()
        ax.set_yticks([])
    ax.set_xticks([])
    ax.tick_params(width=0.0, labelsize=25)

    plt.close(fig)
    return fig


iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs=["plot"],
    examples=[
        "image (1).jpg",
        "image (2).jpg",
        "image (3).jpg",
        "image (4).jpg",
        "image (5).jpg",
    ],
    allow_flagging="never",
)

iface.launch(share=True)