aliabd HF staff committed on
Commit
650f85a
1 Parent(s): a37b37c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # URL: https://huggingface.co/spaces/gradio/image_segmentation/
2
+ # imports
3
+ import gradio as gr
4
+ from transformers import DetrFeatureExtractor, DetrForSegmentation
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ import itertools
10
+ import seaborn as sns
11
+
12
# Load the pretrained DETR panoptic-segmentation weights from the Hugging Face
# Hub. The preprocessor and the network come from the same checkpoint, so the
# identifier is spelled once and shared by both loaders.
_CHECKPOINT = 'facebook/detr-resnet-50-panoptic'
feature_extractor = DetrFeatureExtractor.from_pretrained(_CHECKPOINT)
model = DetrForSegmentation.from_pretrained(_CHECKPOINT)
15
+
16
def predict_animal_mask(im,
                        gr_slider_confidence):
    """Overlay DETR segmentation masks on a downscaled copy of ``im``.

    The image is resized to 200x200, run through the module-level
    ``feature_extractor`` and ``model``, and every query whose best class
    probability exceeds the threshold contributes one mask. Each pixel is
    assigned to the kept query with the highest mask logit and painted with
    a flat colour, which is then blended 75/25 over the resized image.

    Args:
        im: Input image as a NumPy array accepted by ``Image.fromarray``.
        gr_slider_confidence: Confidence threshold in percent (0-100).

    Returns:
        A ``(200, 200, 3)`` ``uint8`` NumPy array. When no query passes the
        threshold, the resized image is returned without any overlay.
    """
    # PIL sizes are (width, height) while NumPy images are (height, width, C);
    # keep both names so the two conventions are never mixed up.
    width, height = 200, 200
    # Convert up front so the model and the final blend see the same RGB data.
    image = Image.fromarray(im).convert('RGB').resize((width, height))

    encoding = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)

    # Best class probability per query, ignoring the last class slot
    # (DETR's "no object" class per the model documentation).
    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
    keep = prob_per_query > gr_slider_confidence / 100.0
    # Boolean indexing over (batch, query) leaves (num_kept, mask_h, mask_w).
    kept_masks = outputs.pred_masks[keep]

    base = np.asarray(image)
    if kept_masks.shape[0] == 0:
        # Nothing is confident enough: argmax over an empty axis would raise,
        # so return the plain resized image instead.
        return base.astype(np.uint8)

    if tuple(kept_masks.shape[-2:]) != (height, width):
        # The mask resolution is decided by the preprocessor/model, not by
        # this function; resample so the labels line up with the image.
        kept_masks = torch.nn.functional.interpolate(
            kept_masks.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

    # No squeeze() here: with exactly one kept query it would drop the query
    # axis and make argmax reduce over image rows instead of over queries.
    label_per_pixel = torch.argmax(kept_masks, dim=0).cpu().numpy()

    color_mask = np.zeros((height, width, 3))
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        # Palette entries are RGB floats in [0, 1]; scale to the 0-255 range.
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255

    pred_img = base * 0.25 + color_mask * 0.75
    return pred_img.astype(np.uint8)
35
+
36
+
37
# Input widgets: the picture to segment and the mask-confidence threshold
# (a percentage, which the prediction function divides by 100).
gr_image_input = gr.inputs.Image()
gr_slider_confidence = gr.inputs.Slider(
    minimum=0,
    maximum=100,
    step=5,
    default=85,
    label='Set confidence threshold for masks',
)

# Output widget: the image with the coloured masks blended in.
gr_image_output = gr.outputs.Image()

# Wire the prediction function and the widgets into a single interface.
# Each example row supplies one value per input: [image file, threshold].
demo = gr.Interface(
    fn=predict_animal_mask,
    inputs=[gr_image_input, gr_slider_confidence],
    outputs=gr_image_output,
    title='Image segmentation with varying confidence',
    description="A panoptic (semantic+instance) segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone",
    examples=[
        ["cheetah.jpg", 75],
        ["lion.jpg", 85],
    ],
)

# Start the web app.
demo.launch()