File size: 4,410 Bytes
f9e4a95
 
 
 
 
562224f
 
 
 
912be9c
 
272aa37
 
562224f
 
 
6c333c9
 
 
 
 
 
 
272aa37
f9e4a95
 
 
 
 
 
9267ec1
 
f9e4a95
6c333c9
 
f9e4a95
562224f
912be9c
17145cb
 
 
562224f
 
 
6c333c9
f9e4a95
6c333c9
562224f
 
 
912be9c
 
 
f9e4a95
4f1cf17
562224f
f9e4a95
6c333c9
 
 
 
 
 
 
f9e4a95
 
272aa37
f9e4a95
 
17145cb
 
562224f
 
 
 
 
 
 
912be9c
 
562224f
 
f9e4a95
 
 
 
912be9c
 
 
6c333c9
f9e4a95
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Using as reference:
- https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
- https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
- https://huggingface.co/facebook/detr-resnet-50-panoptic
- https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ (COCO category labels)

https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_panoptic_segmentation_minimal_example_(with_DetrFeatureExtractor).ipynb

https://arxiv.org/abs/2005.12872

https://arxiv.org/pdf/1801.00868.pdf 

Additions
- show the labels of the displayed masks as text
- show only animal masks (use an NLP model to pick the animal labels?)

For next time
- changing the confidence threshold should change which high-confidence masks are shown
- colors are not great; should they be constant per class? Add text labels?
- I'm getting a core dump (segmentation fault) when loading the Hugging Face model :(
    https://github.com/huggingface/transformers/issues/16939 
- cap the slider at 95?
- switch between panoptic and semantic segmentation?
"""

from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision

import itertools
import seaborn as sns

def predict_animal_mask(im,
                        gr_slider_confidence):
    """Segment an image with DETR (panoptic) and overlay the confident masks.

    The input is resized to 200x200 and run through the module-level
    ``feature_extractor`` and ``model``. Every query whose best class
    probability (excluding the final "no-object" class) exceeds the slider
    threshold contributes a mask. Each pixel is assigned to the kept query
    with the highest mask score and painted with a color cycled from the
    seaborn palette.

    Args:
        im: image from the gradio ``Image`` input as a numpy array
            (presumably HxWx3 uint8 -- TODO confirm; grayscale/RGBA inputs
            are converted to RGB). ``None`` when no image was provided.
        gr_slider_confidence: confidence threshold in percent (0-100).

    Returns:
        A 200x200x3 uint8 array blending the resized image (25%) with the
        colored mask (75%); the plain resized image when no query passes
        the threshold; ``None`` when ``im`` is ``None``.
    """
    if im is None:
        return None

    # Convert to RGB up front so grayscale / RGBA uploads reach the
    # feature extractor with 3 channels.
    image = Image.fromarray(im).convert('RGB')
    image = image.resize((200, 200))  # TODO: upsample the output instead?

    # encoding is a dict with pixel_values and pixel_mask (pt = PyTorch)
    encoding = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(**encoding)

    # Per-query confidence: best class probability, excluding the
    # "no-object" class (the last logit).
    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
    keep = prob_per_query > gr_slider_confidence / 100.0

    # Boolean indexing over (batch, queries) gives (num_kept, h, w).
    # Do NOT squeeze: that would drop the query axis when exactly one
    # query is kept and break the per-pixel argmax below.
    kept_masks = outputs.pred_masks[keep]
    if kept_masks.shape[0] == 0:
        # Nothing passes the threshold; argmax over an empty axis would fail.
        return np.array(image)

    # Resample the mask scores to the image resolution so the overlay
    # always matches the image, whatever resolution the model outputs.
    height, width = image.height, image.width
    if tuple(kept_masks.shape[-2:]) != (height, width):
        kept_masks = torch.nn.functional.interpolate(
            kept_masks[None], size=(height, width),
            mode='bilinear', align_corners=False)[0]

    # Each pixel gets the index of the kept query with the highest score.
    label_per_pixel = kept_masks.argmax(dim=0).cpu().numpy()

    color_mask = np.zeros((height, width, 3))
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255

    # Show image + mask
    pred_img = np.array(image) * 0.25 + color_mask * 0.75
    return pred_img.astype(np.uint8)

#######################################
# Load the DETR panoptic-segmentation checkpoint from the Hugging Face Hub.
# feature_extractor and model are module-level globals that
# predict_animal_mask reads when it runs.
CHECKPOINT = 'facebook/detr-resnet-50-panoptic'
feature_extractor = DetrFeatureExtractor.from_pretrained(CHECKPOINT)
model = DetrForSegmentation.from_pretrained(CHECKPOINT)

# gradio input components: the image to segment and the confidence
# threshold (in percent) used to keep predicted masks
gr_image_input = gr.inputs.Image()
gr_slider_confidence = gr.inputs.Slider(
    minimum=0,
    maximum=100,
    step=5,
    default=85,
    label='Set confidence threshold for masks',
)
# gradio output component: the image with the colored masks overlaid
gr_image_output = gr.outputs.Image() 

####################################################
# Build the user interface, then launch it
demo = gr.Interface(
    predict_animal_mask,
    inputs=[gr_image_input, gr_slider_confidence],
    outputs=gr_image_output,
    title='Image segmentation with varying confidence',
    description="A panoptic (semantic+instance) segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone",
)
demo.launch()


####################################
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)

# inputs = feature_extractor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)