# Hugging Face Space status: Running
# URL: https://huggingface.co/spaces/gradio/image_segmentation/
# imports | |
import gradio as gr | |
from transformers import DetrFeatureExtractor, DetrForSegmentation | |
from PIL import Image | |
import numpy as np | |
import torch | |
import torchvision | |
import itertools | |
import seaborn as sns | |
# Load the pretrained DETR panoptic checkpoint by its Hugging Face Hub id.
# The feature extractor and the model must come from the same checkpoint,
# so the id is written once and shared.
_DETR_CHECKPOINT = 'facebook/detr-resnet-50-panoptic'
feature_extractor = DetrFeatureExtractor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForSegmentation.from_pretrained(_DETR_CHECKPOINT)
def predict_animal_mask(im,
                        gr_slider_confidence):
    """Overlay DETR panoptic segmentation masks on an input image.

    The image is converted to RGB and resized to 200x200, then run through the
    module-level ``feature_extractor`` and ``model``. Every query whose best
    class probability exceeds the threshold is kept. The last class column,
    DETR's "no object" class, is excluded from that maximum. Each pixel is
    assigned to the kept query with the highest mask logit. Each resulting
    label is painted with the next colour from a cycled seaborn palette and
    blended 75% colour / 25% image.

    Args:
        im: Input image as a numpy array (as delivered by a Gradio image
            component), or ``None`` when nothing was uploaded.
        gr_slider_confidence: Confidence threshold in percent (0-100).

    Returns:
        A ``uint8`` RGB array of shape (200, 200, 3) with the coloured masks
        blended in. Returns the resized image unchanged if no query passes
        the threshold, or ``None`` if ``im`` is ``None``.
    """
    if im is None:
        return None

    # Convert before inference so RGBA / greyscale uploads reach the model
    # as 3-channel RGB.
    image = Image.fromarray(im).convert('RGB').resize((200, 200))
    base = np.array(image)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoding = feature_extractor(images=image, return_tensors="pt")
        outputs = model(**encoding)

    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
    keep = prob_per_query > gr_slider_confidence / 100.0
    # A boolean (1, Q) index over (1, Q, h, w) masks yields (k, h, w).
    # Deliberately no .squeeze(): with k == 1 it would drop the query axis
    # and argmax(dim=0) would then reduce over image rows instead.
    kept_masks = outputs.pred_masks[keep]
    if kept_masks.shape[0] == 0:
        # argmax over an empty query axis raises; there is nothing to overlay.
        return base

    height, width = base.shape[:2]
    if kept_masks.shape[-2:] != (height, width):
        # Mask logits come at the model's own resolution; resize them so the
        # label map can index the image-sized colour mask.
        kept_masks = torch.nn.functional.interpolate(
            kept_masks[None], size=(height, width), mode="bilinear",
            align_corners=False)[0]
    label_per_pixel = kept_masks.argmax(dim=0).numpy()

    # numpy arrays are (rows, cols) = (height, width); PIL's .size is
    # (width, height), so build the shape from the array instead.
    color_mask = np.zeros((height, width, 3))
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255
    pred_img = base * 0.25 + color_mask * 0.75
    return pred_img.astype(np.uint8)
# Gradio UI. The legacy gr.inputs / gr.outputs namespaces were removed in
# Gradio 4, so the top-level components are used instead.
# define inputs
gr_image_input = gr.Image()  # default type="numpy", as predict_animal_mask expects
# Keyword arguments: gr.Slider's positional order is (minimum, maximum, value),
# unlike the legacy (minimum, maximum, step, default).
gr_slider_confidence = gr.Slider(minimum=0, maximum=100, step=5, value=85,
                                 label='Set confidence threshold for masks')
# define output
gr_image_output = gr.Image()
# define interface
demo = gr.Interface(predict_animal_mask,
                    inputs=[gr_image_input, gr_slider_confidence],
                    outputs=gr_image_output,
                    title='Image segmentation with varying confidence',
                    description="A panoptic (semantic+instance) segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone",
                    examples=[["cheetah.jpg", 75], ["lion.jpg", 85]])
# launch only when run as a script, so importing this module does not
# start a server.
if __name__ == "__main__":
    demo.launch()