aliabd HF staff committed on
Commit
cc194ea
1 Parent(s): a3ee9e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -43
app.py CHANGED
@@ -1,53 +1,50 @@
1
  # URL: https://huggingface.co/spaces/gradio/image_segmentation/
2
  # imports
3
  import gradio as gr
4
- from transformers import DetrFeatureExtractor, DetrForSegmentation
5
- from PIL import Image
6
- import numpy as np
7
  import torch
8
- import torchvision
9
- import itertools
10
- import seaborn as sns
 
11
 
12
- # load model from hugging face
13
- feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50-panoptic')
14
- model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')
 
 
15
 
16
- def predict_animal_mask(im,
17
- gr_slider_confidence):
18
- image = Image.fromarray(im)
19
- image = image.resize((200,200))
20
- encoding = feature_extractor(images=image, return_tensors="pt")
21
- outputs = model(**encoding)
22
- logits = outputs.logits
23
- bboxes = outputs.pred_boxes
24
- masks = outputs.pred_masks
25
- prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
26
- keep = prob_per_query > gr_slider_confidence/100.0
27
- label_per_pixel = torch.argmax(masks[keep].squeeze(),dim=0).detach().numpy()
28
- color_mask = np.zeros(image.size+(3,))
29
- palette = itertools.cycle(sns.color_palette())
30
- for lbl in np.unique(label_per_pixel):
31
- color_mask[label_per_pixel==lbl,:] = np.asarray(next(palette))*255
32
- pred_img = np.array(image.convert('RGB'))*0.25 + color_mask*0.75
33
- pred_img = pred_img.astype(np.uint8)
34
- return pred_img
35
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # define inputs
38
- gr_image_input = gr.inputs.Image()
39
- gr_slider_confidence = gr.inputs.Slider(0,100,5,85,
40
- label='Set confidence threshold for masks')
41
- # define output
42
- gr_image_output = gr.outputs.Image()
43
 
44
- # define interface
45
- demo = gr.Interface(predict_animal_mask,
46
- inputs = [gr_image_input,gr_slider_confidence],
47
- outputs = gr_image_output,
48
- title = 'Image segmentation with varying confidence',
49
- description = "A panoptic (semantic+instance) segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone",
50
- examples=[["cheetah.jpg", 75], ["lion.jpg", 85]])
51
 
52
- # launch
53
- demo.launch()
 
1
  # URL: https://huggingface.co/spaces/gradio/image_segmentation/
2
  # imports
3
  import gradio as gr
 
 
 
4
  import torch
5
+ import random
6
+ import numpy as np
7
+ from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
8
+
9
 
10
# load model
# One checkpoint id is used for both the weights and the matching preprocessor.
checkpoint = "facebook/maskformer-swin-tiny-ade"
device = torch.device("cpu")
model = MaskFormerForInstanceSegmentation.from_pretrained(checkpoint)
model = model.to(device)
model.eval()  # switch the module to evaluation mode
preprocessor = MaskFormerFeatureExtractor.from_pretrained(checkpoint)
15
 
16
+ # define core and helper fns
17
def visualize_instance_seg_mask(mask):
    """Render an integer label map as a randomly colored RGB image.

    Every distinct value in ``mask`` is assigned one random color, and each
    pixel is painted with the color of its label.

    Args:
        mask: 2-D NumPy array of label ids, indexed as ``mask[row, col]``.

    Returns:
        Float array of shape ``(rows, cols, 3)`` with channel values in
        ``[0, 1]``. Colors are drawn from ``random``, so they differ between
        calls unless the caller seeds the generator.
    """
    labels, inverse = np.unique(mask, return_inverse=True)
    # One RGB triple per label; all three channels span the full 0-255 range
    # (the red channel was previously limited to 0-1, i.e. effectively black).
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)  # keeps a (0, 3) shape when the mask is empty
    # `inverse` maps each pixel to its row in `palette`; the reshape covers
    # NumPy versions that return it flattened.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
 
 
 
 
 
 
 
 
 
26
 
27
def query_image(img):
    """Segment ``img`` with the MaskFormer model and return a colorized map.

    Args:
        img: Image array; its first two dimensions are used as the output
            size (presumably height and width - TODO confirm against the
            Gradio image component).

    Returns:
        The array produced by ``visualize_instance_seg_mask`` for the
        per-pixel argmax of the post-processed segmentation.
    """
    output_size = (img.shape[0], img.shape[1])
    batch = preprocessor(images=img, return_tensors="pt")

    # Inference only, so no autograd bookkeeping is needed.
    with torch.no_grad():
        outputs = model(**batch)

    # Move both logit tensors to the CPU before post-processing.
    for field in ("class_queries_logits", "masks_queries_logits"):
        setattr(outputs, field, getattr(outputs, field).cpu())

    segmentation = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=output_size
    )
    # Keep the first entry of the batch, then take the winning index along dim 0.
    first = segmentation[0].cpu().detach()
    label_map = torch.argmax(first, dim=0).numpy()
    return visualize_instance_seg_mask(label_map)
38
 
39
+ # define interface
 
 
 
 
 
40
 
41
# One image in, one colorized segmentation image out.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    examples=["example_1.png", "example_2.png"],
)

# launch
demo.launch()