fffiloni committed on
Commit
2feea8a
1 Parent(s): 2dc6006

Create ref_seg.py

Browse files
Files changed (1) hide show
  1. tasks/ref_seg.py +46 -0
tasks/ref_seg.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+ from utils.visualizer import Visualizer
13
+ from detectron2.utils.colormap import random_color
14
+ from detectron2.data import MetadataCatalog
15
+
16
+
17
# Preprocessing pipeline shared by every call: a single bicubic Resize(512).
t = [transforms.Resize(512, interpolation=Image.BICUBIC)]
transform = transforms.Compose(t)

# Metadata entry looked up once at import time; it is handed to both the
# model and the Visualizer in referring_segmentation.
metadata = MetadataCatalog.get('ade20k_panoptic_train')
21
+
22
def _parse_grounding_texts(texts):
    """Split a comma-separated prompt into one grounding query per phrase.

    Every phrase is stripped of surrounding whitespace, given a trailing
    period if it lacks one, and wrapped in its own one-element list (the
    nesting that is passed on under ``groundings['texts']``).
    """
    queries = []
    for phrase in texts.split(','):
        # Strip *before* testing for the period: otherwise " dog" keeps its
        # leading blank (" dog.") and "cat. " gains a second period ("cat. .").
        phrase = phrase.strip()
        if not phrase.endswith('.'):
            phrase += '.'
        queries.append([phrase])
    return queries


def referring_segmentation(model, image, texts, inpainting_text, *args, **kwargs):
    """Run referring segmentation on *image* and draw one mask per phrase.

    Args:
        model: Object whose ``model`` attribute accepts a ``metadata``
            attribute and provides ``evaluate_grounding``.
        image: Input picture. Presumably a PIL image, since the resized
            result is read through ``.size[...]`` and ``np.asarray`` --
            confirm against the caller.
        texts: Comma-separated referring phrases.
        inpainting_text: Not used by this function.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A 3-tuple ``(visualisation, '', None)`` where ``visualisation`` is a
        PIL image built from the Visualizer output. If the model yields no
        masks, the resized input image is returned unannotated.
    """
    model.model.metadata = metadata
    texts = _parse_grounding_texts(texts)
    image_ori = transform(image)

    with torch.no_grad():
        width = image_ori.size[0]
        height = image_ori.size[1]
        image_ori_np = np.asarray(image_ori)
        # HWC -> CHW on the GPU; the copy keeps the tensor's buffer separate
        # from the array that is handed to the Visualizer below.
        images = torch.from_numpy(image_ori_np.copy()).permute(2, 0, 1).cuda()

        batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts}}]
        outputs = model.model.evaluate_grounding(batch_inputs, None)
        visual = Visualizer(image_ori_np, metadata=metadata)

        # Binarise the grounding output: positive values become 1.0, the rest 0.0.
        grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
        demo = None
        for idx, mask in enumerate(grd_mask):
            # NOTE(review): with maximum=1 the int32 cast presumably truncates
            # each channel to 0 or 1 -- confirm this palette is intended.
            color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
            demo = visual.draw_binary_mask(mask, color=color, text=texts[idx])
        # Without any mask there is nothing drawn; fall back to the plain
        # resized image instead of failing on an unbound variable.
        res = demo.get_image() if demo is not None else image_ori_np

    torch.cuda.empty_cache()
    return Image.fromarray(res), '', None