fffiloni commited on
Commit
d561f97
1 Parent(s): a6ba22e

Create open_inst.py

Browse files
Files changed (1) hide show
  1. tasks/open_inst.py +60 -0
tasks/open_inst.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+ from utils.visualizer import Visualizer
13
+ from detectron2.utils.colormap import random_color
14
+ from detectron2.data import MetadataCatalog
15
+ from detectron2.structures import BitMasks
16
+
# Preprocessing pipeline: resize the image's shorter side to 512 px (bicubic).
t = [transforms.Resize(512, interpolation=Image.BICUBIC)]
transform = transforms.Compose(t)
# Default metadata; open_instseg() swaps in per-call "demo" metadata instead.
metadata = MetadataCatalog.get('ade20k_panoptic_train')
def open_instseg(model, image, texts, inpainting_text, *args, **kwargs):
    """Run open-vocabulary instance segmentation with a user-supplied class list.

    Args:
        model: wrapper exposing the X-Decoder network at ``model.model``
            (with ``sem_seg_head`` / language encoder) and a ``forward``
            taking detectron2-style batched inputs.
        image: input PIL image.
        texts: comma-separated class names, e.g. ``"cat, dog, car"``.
        inpainting_text: unused here; kept so all task entry points share
            the same signature.
        *args, **kwargs: ignored; accepted for signature uniformity.

    Returns:
        Tuple of (PIL image with drawn instance predictions, '', None).
    """
    # Parse the vocabulary and give each class a random display color.
    thing_classes = [x.strip() for x in texts.split(',')]
    thing_colors = [random_color(rgb=True, maximum=255).astype(np.int32).tolist()
                    for _ in range(len(thing_classes))]
    thing_dataset_id_to_contiguous_id = {x: x for x in range(len(thing_classes))}

    # Register transient metadata under the "demo" key so the visualizer can
    # map predicted class ids to names/colors.
    MetadataCatalog.get("demo").set(
        thing_colors=thing_colors,
        thing_classes=thing_classes,
        thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
    )

    try:
        # Refresh the text embeddings for the new vocabulary (+ background).
        with torch.no_grad():
            model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
                thing_classes + ["background"], is_eval=True)

        metadata = MetadataCatalog.get('demo')
        model.model.metadata = metadata
        model.model.sem_seg_head.num_classes = len(thing_classes)

        # Preprocess: resize via the module-level `transform`, then convert
        # HWC uint8 array -> CHW tensor on the GPU (model is assumed CUDA).
        image_ori = transform(image)
        width = image_ori.size[0]
        height = image_ori.size[1]
        image = np.asarray(image_ori)
        images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()

        batch_inputs = [{'image': images, 'height': height, 'width': width}]
        # Inference only: run under no_grad so autograd does not retain
        # activations (the original tracked gradients for this forward pass).
        with torch.no_grad():
            outputs = model.forward(batch_inputs)

        visual = Visualizer(image_ori, metadata=metadata)

        inst_seg = outputs[-1]['instances']
        inst_seg.pred_masks = inst_seg.pred_masks.cpu()
        inst_seg.pred_boxes = BitMasks(inst_seg.pred_masks > 0).get_bounding_boxes()
        demo = visual.draw_instance_predictions(inst_seg)  # rgb Image
        res = demo.get_image()
    finally:
        # Always unregister the transient metadata and release cached GPU
        # memory, even on failure: a leftover "demo" entry would make the
        # next call's .set(...) raise when the vocabulary changes.
        MetadataCatalog.remove('demo')
        torch.cuda.empty_cache()

    return Image.fromarray(res), '', None