""" Using as reference: - https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512 - https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py - https://huggingface.co/facebook/detr-resnet-50-panoptic # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_panoptic_segmentation_minimal_example_(with_DetrFeatureExtractor).ipynb https://arxiv.org/abs/2005.12872 https://arxiv.org/pdf/1801.00868.pdf Additions - add shown labels as strings - show only animal masks (ask an nlp model?) For next time - for diff 'confidence' the high conf masks should change.... - colors are not great and should be constant per class? add text? - Im getting core dumped (segmentation fault) when loading hugging face model.. :() https://github.com/huggingface/transformers/issues/16939 - cap slider to 95? - switch between panoptic and semantic? """ from transformers import DetrFeatureExtractor, DetrForSegmentation from PIL import Image import gradio as gr import numpy as np import torch import torchvision import itertools import seaborn as sns def predict_animal_mask(im, gr_slider_confidence): image = Image.fromarray(im) # im: numpy array 3d: 480, 640, 3: to PIL Image image = image.resize((200,200)) # PIL image # could I upsample output instead? better? # encoding is a dict with pixel_values and pixel_mask encoding = feature_extractor(images=image, return_tensors="pt") #pt=Pytorch, tf=TensorFlow outputs = model(**encoding) # odict with keys: ['logits', 'pred_boxes', 'pred_masks', 'last_hidden_state', 'encoder_last_hidden_state'] logits = outputs.logits # torch.Size([1, 100, 251]); class logits? but why 251? bboxes = outputs.pred_boxes masks = outputs.pred_masks # torch.Size([1, 100, 200, 200]); mask logits? for every pixel, score in each of the 100 classes? there is a mask per class # keep only the masks with high confidence?-------------------------------- # compute the prob per mask (i.e., class), excluding the "no-object" class (the last one) prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0] # why logits last dim 251? 
    # threshold on the confidence set by the slider
    keep = prob_per_query > gr_slider_confidence / 100.0

    # postprocess the masks (numpy arrays): for each pixel, pick the kept query
    # with the highest mask logit
    # note: masks[keep] has shape [num_kept, 200, 200]; if the threshold is too
    # high and no query is kept, the argmax fails (hence the idea of capping the slider)
    label_per_pixel = torch.argmax(masks[keep], dim=0).detach().numpy()

    # paint each label with a color from a cycling palette
    # (PIL's image.size is (W, H), so reverse it for numpy's (H, W) indexing)
    color_mask = np.zeros(image.size[::-1] + (3,))
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255
    # alternative with a fixed palette:
    # for lbl, color in enumerate(ade_palette()):
    #     color_mask[label_per_pixel == lbl, :] = color

    # overlay the color mask on the (dimmed) input image
    pred_img = np.array(image.convert('RGB')) * 0.25 + color_mask * 0.75
    pred_img = pred_img.astype(np.uint8)

    return pred_img


#######################################
# get the models from the Hugging Face hub
feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50-panoptic')
model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')

# gradio components - inputs
gr_image_input = gr.inputs.Image()
gr_slider_confidence = gr.inputs.Slider(0, 100, 5, 85,  # min, max, step, default
                                        label='Set confidence threshold for masks')

# gradio components - outputs
gr_image_output = gr.outputs.Image()

####################################################
# create the user interface and launch
gr.Interface(predict_animal_mask,
             inputs=[gr_image_input, gr_slider_confidence],
             outputs=gr_image_output,
             title='Image segmentation with varying confidence',
             description="A panoptic (semantic + instance) segmentation webapp using the DETR (End-to-End Object Detection) model with a ResNet-50 backbone").launch()

####################################
# minimal usage example from the model card:
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
# inputs = feature_extractor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits  # shape (batch_size, num_queries, num_classes + 1)
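

########################################################################
# A minimal sketch for two of the docstring TODOs ("show only animal
# masks" and "colors ... constant per class"). It is not wired into the
# app above: it assumes `outputs`, `keep`, and `masks` computed as in
# predict_animal_mask, reads class names from model.config.id2label, and
# ANIMAL_NAMES below is a hand-picked subset of the COCO animal classes,
# not an exhaustive list.

ANIMAL_NAMES = {'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                'elephant', 'bear', 'zebra', 'giraffe'}

# one fixed color per class id, so a class always gets the same color
FIXED_PALETTE = np.asarray(sns.color_palette('hls', model.config.num_labels)) * 255


def kept_query_class_ids(outputs, keep):
    """Class id of each kept query: argmax over class logits, excluding 'no object'."""
    class_per_query = outputs.logits.softmax(-1)[..., :-1].argmax(-1)  # [1, 100]
    return class_per_query[keep]  # [num_kept]


def animal_query_filter(class_ids):
    """Boolean mask over the kept queries selecting only animal classes."""
    return torch.tensor([model.config.id2label[int(c)] in ANIMAL_NAMES
                         for c in class_ids])


def colorize_by_class(label_per_pixel, class_ids):
    """Color each pixel by the class of its winning query, not the query index."""
    class_per_pixel = class_ids.numpy()[label_per_pixel]  # (H, W) array of class ids
    return FIXED_PALETTE[class_per_pixel].astype(np.uint8)

# Hypothetical use inside predict_animal_mask, after computing `keep`:
#   class_ids = kept_query_class_ids(outputs, keep)
#   is_animal = animal_query_filter(class_ids)
#   label_per_pixel = torch.argmax(masks[keep][is_animal], dim=0).detach().numpy()
#   color_mask = colorize_by_class(label_per_pixel, class_ids[is_animal])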