Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
| from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation | |
| import matplotlib.pyplot as plt | |
| from matplotlib import gridspec | |
# Preprocessor and TF segmentation model are both loaded once, at import time,
# from the same Hugging Face checkpoint name.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = TFSegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
def ade_palette():
    """Return the overlay colour table: one ``[R, G, B]`` list per class index.

    Despite the function name, this is not the ADE20K palette. It is a
    19-entry table, presumably one colour per Cityscapes class of the
    module-level checkpoint (TODO confirm it lines up with labels.txt).
    A fresh list of lists is built on every call, so callers may mutate it.
    """
    rgb_rows = (
        (255, 0, 0),
        (255, 187, 0),
        (255, 228, 0),
        (29, 219, 22),
        (178, 204, 255),
        (1, 0, 255),
        (165, 102, 255),
        (217, 65, 197),
        (116, 116, 116),
        (204, 114, 61),
        (206, 242, 121),
        (61, 183, 204),
        (94, 94, 94),
        (196, 183, 59),
        (246, 246, 246),
        (209, 178, 255),
        (0, 87, 102),
        (153, 0, 76),
        (47, 157, 39),
    )
    return [list(rgb) for rgb in rgb_rows]
# Class names, one per line of labels.txt. The list index is assumed to be the
# class id the model predicts (TODO confirm the file order matches the checkpoint).
with open(r'labels.txt', 'r', encoding='utf-8') as fp:
    # Strip only the line terminator: slicing off the last character would
    # truncate a real name on a final line that has no trailing newline.
    labels_list = [line.rstrip('\n') for line in fp]

# Palette as an (n_colours, 3) integer array so label maps can index it directly.
colormap = np.asarray(ade_palette())
def label_to_color_image(label, palette=None):
    """Map a 2-D array of class indices to an RGB image.

    Args:
        label: 2-D integer array (or array-like) of class indices.
        palette: optional ``(n_classes, 3)`` colour table; when omitted the
            module-level ``colormap`` is used.

    Returns:
        Array of shape ``label.shape + (3,)`` whose pixel ``[i, j]`` is
        ``palette[label[i, j]]``.

    Raises:
        ValueError: if ``label`` is not 2-D, or contains a value outside
            ``[0, len(palette))``.
    """
    palette = colormap if palette is None else np.asarray(palette)
    label = np.asarray(label)
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    # np.max/np.min raise on zero-size arrays, so only range-check a non-empty map.
    if label.size:
        if np.max(label) >= len(palette):
            raise ValueError("label value too large.")
        # A negative index would silently wrap to the end of the palette.
        if np.min(label) < 0:
            raise ValueError("label value must be non-negative.")
    return palette[label]
def draw_plot(pred_img, seg):
    """Build a figure showing the overlay image next to a legend of detected classes.

    Args:
        pred_img: image array to display (``sepia`` passes the blended overlay).
        seg: 2-D map of per-pixel class indices; a TF eager tensor or anything
            ``np.asarray`` accepts.

    Returns:
        A ``matplotlib.figure.Figure``. The left panel shows ``pred_img``; the
        right panel shows one colour swatch per class present in ``seg``,
        labelled with its name from ``labels_list``.
    """
    # Build the Figure directly instead of through pyplot. pyplot keeps every
    # figure it creates in a global registry until plt.close(), so a
    # per-request plt.figure() grows memory without bound in a long-running
    # server, and its implicit "current axes" state is shared across requests.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(20, 15))
    grid_spec = fig.add_gridspec(1, 2, width_ratios=[6, 1])

    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis('off')

    label_names = np.asarray(labels_list)
    # Column of every class index -> colour table of shape (n_labels, 1, 3).
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)
    # np.asarray works for both TF eager tensors and NumPy arrays.
    unique_labels = np.unique(np.asarray(seg).astype("uint8"))

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(full_color_map[unique_labels].astype(np.uint8), interpolation="nearest")
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
def sepia(input_img):
    """Segment an image with SegFormer and report per-class confidence.

    Despite its name, this applies no sepia filter: it is the Gradio handler
    that runs the module-level SegFormer ``model`` on ``input_img``.

    Args:
        input_img: image as a NumPy array (converted with ``Image.fromarray``).

    Returns:
        ``(fig, output_text)``:
        - ``fig`` is the figure from ``draw_plot``, a 50/50 blend of the image
          and the class colours, plus a legend.
        - ``output_text`` is a report of the mean predicted probability, in
          percent, of each class present, followed by the class with the
          highest such value.
    """
    input_img = Image.fromarray(input_img)
    inputs = feature_extractor(images=input_img, return_tensors="tf")
    outputs = model(**inputs)

    # Move the class axis last so tf.image.resize can upsample the logits to
    # the input's (height, width); PIL's .size is (width, height), hence [::-1].
    logits = tf.transpose(outputs.logits, [0, 2, 3, 1])
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]
    seg_np = seg.numpy()  # convert once rather than once per class

    # Paint each pixel with its class colour; ids beyond the palette stay black.
    color_seg = np.zeros((seg_np.shape[0], seg_np.shape[1], 3), dtype=np.uint8)
    in_palette = seg_np < len(colormap)
    color_seg[in_palette] = colormap[seg_np[in_palette]]
    pred_img = (np.array(input_img) * 0.5 + color_seg * 0.5).astype(np.uint8)

    fig = draw_plot(pred_img, seg)

    # Per-pixel class probabilities: softmax over the class axis (last, after
    # the transpose) so each pixel's scores sum to 1.
    probs = tf.nn.softmax(logits[0], axis=-1).numpy()

    # Predicted class and probability for each object: the mean probability of
    # each predicted class over the pixels assigned to it, as a percentage.
    class_probabilities = {}
    for label in np.unique(seg_np.astype("uint8")):
        mask = seg_np == label
        class_name = labels_list[label]
        class_probabilities[class_name] = np.mean(probs[..., label][mask]) * 100

    # Build the text shown in the Gradio interface.
    output_text = "Predicted class probabilities:\n" + "".join(
        f"{class_name}: {prob:.2f}%\n" for class_name, prob in class_probabilities.items()
    )
    # Append the class with the highest mean probability.
    max_prob_class = max(class_probabilities, key=class_probabilities.get)
    max_prob_value = class_probabilities[max_prob_class]
    output_text += f"\nPredicted class with highest probability: {max_prob_class} \n Probability: {max_prob_value:.4f}%"
    return fig, output_text
# Gradio UI: one image in; a matplotlib plot and a text report out.
# The example file names must match the image files shipped alongside this script.
example_images = [
    "citiscapes-1.jpeg",
    "citiscapes-2.jpeg",
    "citiscapes-3.jpeg",
    "citiscapes-4.jpeg",
]
demo = gr.Interface(
    fn=sepia,
    inputs=gr.Image(shape=(400, 600)),
    outputs=["plot", "text"],
    examples=example_images,
    allow_flagging="never",
)
demo.launch()