# URL: https://huggingface.co/spaces/gradio/image_segmentation/
"""Gradio demo: panoptic segmentation with DETR (ResNet-50 backbone).

The app resizes the uploaded image to 200x200, runs DETR panoptic
segmentation, keeps only query masks above a user-chosen confidence
threshold, and blends a per-label color mask over the image.
"""

# imports
import itertools

import gradio as gr
import numpy as np
import seaborn as sns
import torch
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForSegmentation

# Load the pretrained panoptic segmentation model from the Hugging Face hub.
feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50-panoptic')
model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')


def predict_animal_mask(im, gr_slider_confidence):
    """Segment an image and return it blended with a colored label mask.

    Parameters
    ----------
    im : np.ndarray
        Input image as an HxWx3 uint8 array (Gradio's default numpy format).
    gr_slider_confidence : float
        Confidence threshold in percent (0-100); query masks whose maximum
        class probability falls below this are discarded.

    Returns
    -------
    np.ndarray
        200x200x3 uint8 array: 25% resized input + 75% color mask overlay.
    """
    image = Image.fromarray(im)
    # Fixed 200x200 canvas so the overlay arithmetic below is shape-stable.
    image = image.resize((200, 200))

    encoding = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model(**encoding)

    masks = outputs.pred_masks  # (1, num_queries, h, w)

    # Max class probability per query, excluding the trailing "no object" class.
    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
    # Boolean mask over queries: keep only confident ones.
    keep = prob_per_query > gr_slider_confidence / 100.0

    # One label per pixel: index of the highest-scoring kept query mask.
    # masks[keep] is already (n_kept, h, w); do NOT .squeeze() here — with
    # exactly one kept query, squeeze() would also drop the query axis and
    # argmax(dim=0) would then reduce over image rows instead of queries.
    label_per_pixel = torch.argmax(masks[keep], dim=0).detach().numpy()

    # Paint each label with a distinct seaborn palette color.
    # NOTE(review): image.size is (width, height) while numpy masks are
    # (height, width); this only lines up because the canvas is square.
    color_mask = np.zeros(image.size + (3,))
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255

    # Blend: 25% original image + 75% color mask.
    pred_img = np.array(image.convert('RGB')) * 0.25 + color_mask * 0.75
    return pred_img.astype(np.uint8)


# define inputs (Gradio 3.x+ API; gr.inputs/gr.outputs were removed)
gr_image_input = gr.Image()
gr_slider_confidence = gr.Slider(
    minimum=0,
    maximum=100,
    step=5,
    value=85,
    label='Set confidence threshold for masks',
)

# define output
gr_image_output = gr.Image()

# define interface
demo = gr.Interface(
    predict_animal_mask,
    inputs=[gr_image_input, gr_slider_confidence],
    outputs=gr_image_output,
    title='Image segmentation with varying confidence',
    description=(
        "A panoptic (semantic+instance) segmentation webapp using DETR "
        "(End-to-End Object Detection) model with ResNet-50 backbone"
    ),
    examples=[["cheetah.jpg", 75], ["lion.jpg", 85]],
)

# launch
if __name__ == "__main__":
    demo.launch()