fffiloni commited on
Commit
8858507
1 Parent(s): 2feea8a

Create reg_ret.py

Browse files
Files changed (1) hide show
  1. tasks/reg_ret.py +72 -0
tasks/reg_ret.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6
+ # --------------------------------------------------------
7
+
8
+ import glob
9
+ import os
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+ from detectron2.data import MetadataCatalog
15
+ from utils.visualizer import Visualizer
16
+ from xdecoder.language.loss import vl_similarity
17
+ from detectron2.utils.colormap import random_color
18
+
19
+
20
+ t = []
21
+ t.append(transforms.Resize((224,224), interpolation=Image.BICUBIC))
22
+ transform_ret = transforms.Compose(t)
23
+ t = []
24
+ t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
25
+ transform_grd = transforms.Compose(t)
26
+ metadata = MetadataCatalog.get('coco_2017_train_panoptic')
27
+
28
+ imgs_root = 'images/coco'
29
+ img_pths = sorted(glob.glob(os.path.join(imgs_root, '*.jpg')))
30
+ imgs = [Image.open(x).convert('RGB') for x in img_pths]
31
+ v_emb = torch.load("v_emb.da")
32
+
33
+ def region_retrieval(model, image, texts, inpainting_text, *args, **kwargs):
34
+ model_novg, model_seg = model
35
+ with torch.no_grad():
36
+ # images = [transform_ret(x) for x in imgs]
37
+ # images = [np.asarray(x) for x in imgs]
38
+ # images = [torch.from_numpy(x.copy()).permute(2,0,1).cuda() for x in images]
39
+ # batch_inputs = [{'image': image, 'image_id': 0} for image in images]
40
+ # outputs = model_novg.model.evaluate(batch_inputs)
41
+ # v_emb = torch.cat([x['captions'][-1:] for x in outputs])
42
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
43
+ # torch.save(v_emb, "v_emb.da")
44
+ # exit()
45
+
46
+ texts_ = [[x.strip() if x.strip().endswith('.') else (x.strip() + '.')] for x in texts.split(',')]
47
+ model_novg.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts_, is_eval=False, name='caption', prompt=False)
48
+ t_emb = getattr(model_novg.model.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption'))
49
+ temperature = model_novg.model.sem_seg_head.predictor.lang_encoder.logit_scale
50
+
51
+ logits = vl_similarity(v_emb, t_emb, temperature)
52
+ prob, idx = logits[:,0].softmax(-1).max(0)
53
+ image_ori = imgs[idx]
54
+ image = transform_grd(image_ori)
55
+ width, height = image.size
56
+ image = np.asarray(image)
57
+ image_ori = np.asarray(image)
58
+ images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
59
+ batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts_}}]
60
+ model_seg.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts_, is_eval=False, name='caption', prompt=False)
61
+ outputs = model_seg.model.evaluate_grounding(batch_inputs, None)
62
+
63
+ visual = Visualizer(image_ori, metadata=metadata)
64
+ grd_masks = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
65
+
66
+ for text, mask in zip([x[0] for x in texts_], grd_masks):
67
+ color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
68
+ demo = visual.draw_binary_mask(mask, color=color, text=texts, alpha=0.5)
69
+ res = demo.get_image()
70
+
71
+ torch.cuda.empty_cache()
72
+ return Image.fromarray(res), "Selected Image Probability: {:.2f}".format(prob.item()), None