fffiloni commited on
Commit
6a1cd92
1 Parent(s): 8858507

Create text_ret.py

Browse files
Files changed (1) hide show
  1. tasks/text_ret.py +46 -0
tasks/text_ret.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+ from detectron2.data import MetadataCatalog
13
+ from xdecoder.language.loss import vl_similarity
14
+
15
+
16
+ t = []
17
+ t.append(transforms.Resize(224, interpolation=Image.BICUBIC))
18
+ transform_ret = transforms.Compose(t)
19
+ t = []
20
+ t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
21
+ transform_grd = transforms.Compose(t)
22
+
23
+ metedata = MetadataCatalog.get('coco_2017_train_panoptic')
24
+
25
+ def text_retrieval(model, image, texts, inpainting_text, *args, **kwargs):
26
+ out_str = ''
27
+ with torch.no_grad():
28
+ image = transform_ret(image)
29
+ image = np.asarray(image)
30
+ images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
31
+ batch_inputs = [{'image': images, 'image_id': 0}]
32
+ outputs = model.model.evaluate(batch_inputs)
33
+ v_emb = torch.cat([x['captions'][-1:] for x in outputs])
34
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
35
+
36
+ texts = [x.strip() for x in texts.split(',')]
37
+ model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, is_eval=False, name='caption', prompt=False)
38
+ t_emb = getattr(model.model.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption'))
39
+ temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
40
+ logits = vl_similarity(v_emb, t_emb, temperature)
41
+ topk_prob, topk_idx = logits.softmax(-1)[0].topk(min(5, len(texts)))
42
+
43
+ for prob, idx in zip(topk_prob, topk_idx):
44
+ out_str += "{}:{:.2f}; ".format(texts[idx.item()], prob.item())
45
+ torch.cuda.empty_cache()
46
+ return None, out_str, None