import glob
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import center_of_mass
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation


def combine_ims(im1, im2, val=128):
    """Blend two same-size images with a constant alpha (val/255)."""
    p = Image.new("L", im1.size, val)
    return Image.composite(im1, im2, p)


def get_class_centers(segmentation_mask, class_dict):
    """Return the integer (y, x) center of mass for each class present in the mask.

    The predicted map is 0-based while the class ids are 1-based, hence the +1.
    Classes absent from the mask yield NaN centers and are filtered out.
    """
    segmentation_mask = segmentation_mask.numpy() + 1
    class_centers = {}
    for class_index in class_dict:
        class_mask = (segmentation_mask == class_index).astype(int)
        class_centers[class_index] = center_of_mass(class_mask)
    return {
        k: list(map(int, v))
        for k, v in class_centers.items()
        if not np.isnan(sum(v))
    }


def visualize_mask(predicted_semantic_map, class_ids, class_colors):
    """Map each predicted label to its class id, look up the RGB color, and
    return the colored mask as a PIL image."""
    h, w = predicted_semantic_map.shape
    color_indexes = np.zeros((h, w), dtype=np.uint8)
    color_indexes[:] = predicted_semantic_map.numpy()
    color_indexes = color_indexes.flatten()

    colors = class_colors[class_ids[color_indexes]]
    output = colors.reshape(h, w, 3).astype(np.uint8)
    return Image.fromarray(output)


def get_out_image(image, predicted_semantic_map):
    """Overlay the colored mask on the input image and draw each class name
    at that class's center of mass."""
    class_centers = get_class_centers(predicted_semantic_map, class_dict)
    mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
    image_mask = combine_ims(image, mask, val=128)
    draw = ImageDraw.Draw(image_mask)

    extracted_tags = []
    for class_id, (y, x) in class_centers.items():
        class_name = str(class_names[class_id - 1])
        extracted_tags.append(class_name)
        draw.text((x, y), class_name, fill="black")

    # Join all detected class names into a single " | "-separated string.
    tags_string = " | ".join(extracted_tags)
    return image_mask, tags_string


def gradio_process(image):
    """Run Mask2Former semantic segmentation on a PIL image and return the
    annotated overlay plus the tag string."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes expects (height, width); PIL's .size is (width, height).
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return get_out_image(image, predicted_semantic_map)


# The pickle holds three parallel sequences: class names, 1-based class ids,
# and an RGB color per class.
with open("ade20k_classes.pickle", "rb") as f:
    class_names, class_ids, class_colors = pickle.load(f)
class_names = np.array(class_names)
class_ids = np.array(class_ids)
class_colors = np.array(class_colors)
class_dict = dict(zip(class_ids, class_names))

device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
).to(device)
model.eval()

demo = gr.Interface(
    gradio_process,
    # The gr.inputs / gr.outputs namespaces were removed in Gradio 4;
    # the components are used directly instead.
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Semantic Segmentation",
    examples=glob.glob("./examples/*.jpg"),
    allow_flagging="never",
)
demo.launch()
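
# ---------------------------------------------------------------------------
# Optional helper, not part of the original app: a minimal, hypothetical
# sketch of how ade20k_classes.pickle could be generated if it is missing.
# Assumptions (inferred from the indexing above, not confirmed by the source):
# class ids are 1-based to match the +1 offset in get_class_centers, and
# class_colors gets a dummy row 0 so it can be indexed directly by class id
# in visualize_mask. The palette is random, not the official ADE20K colormap.
# Run it once from a separate session before launching the app.
# ---------------------------------------------------------------------------
def build_ade20k_pickle(path="ade20k_classes.pickle"):
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
    names = [config.id2label[i] for i in sorted(config.id2label)]  # 150 ADE20K names
    ids = list(range(1, len(names) + 1))  # 1-based ids, parallel to names
    rng = np.random.default_rng(0)
    colors = rng.integers(0, 256, size=(len(names) + 1, 3)).tolist()  # row 0 unused
    with open(path, "wb") as f:
        pickle.dump((names, ids, colors), f)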