# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from detectron2.data import MetadataCatalog

from xdecoder.language.loss import vl_similarity


# Retrieval preprocessing: resize so the shorter image edge is 224 px (bicubic).
transform_ret = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
])

# Grounding preprocessing: resize so the shorter image edge is 512 px (bicubic).
# Not used by the function below; kept at module level in case other code imports it.
transform_grd = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
])

# NOTE: the misspelled name `metedata` is kept as-is so existing importers keep
# working; it is not referenced in this module.
metedata = MetadataCatalog.get('coco_2017_train_panoptic')


def text_retrieval(model, image, texts, inpainting_text, *args, top_k=5, **kwargs):
    """Rank comma-separated candidate texts by similarity to an image.

    The image is resized (``transform_ret``) and run through
    ``model.model.evaluate``. The last caption embedding is L2-normalized and
    scored against the language-encoder embeddings of each candidate with
    ``vl_similarity``. The top-k candidates are then reported with their
    softmax probabilities.

    Args:
        model: Wrapper whose ``.model`` provides ``evaluate`` and
            ``sem_seg_head.predictor.lang_encoder``.
        image: Input image. It must support ``.convert('RGB')``, so presumably
            a PIL image; confirm against callers.
        texts: Comma-separated candidate strings. Whitespace around each
            entry is stripped and empty entries are ignored.
        inpainting_text: Unused here. Presumably accepted so that all demo task
            functions share one call signature.
        *args, **kwargs: Ignored.
        top_k: Maximum number of candidates to report (default 5). It is
            clamped to the number of non-empty candidates.

    Returns:
        A 3-tuple ``(None, out_str, None)``. ``out_str`` concatenates
        ``"label:prob; "`` for each reported candidate, in descending
        probability order (probability formatted with 2 decimals). It is
        ``''`` when ``texts`` contains no non-empty candidate.
    """
    # Drop empty entries (e.g. from a trailing comma or an empty string) so
    # they are not embedded and scored as if they were real labels.
    labels = [label for label in (part.strip() for part in texts.split(',')) if label]
    if not labels:
        return None, '', None

    with torch.no_grad():
        # Force 3 channels: a grayscale ('L') image would yield a 2-D array that
        # cannot be permuted to CHW, and an RGBA image would yield 4 channels.
        image = transform_ret(image.convert('RGB'))
        pixels = np.asarray(image)
        # HWC array -> CHW tensor on the GPU. copy() hands torch a writable
        # buffer (np.asarray of a PIL image may be read-only).
        images = torch.from_numpy(pixels.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{'image': images, 'image_id': 0}]
        outputs = model.model.evaluate(batch_inputs)

        # Take the last caption embedding per output. The [-1:] slice keeps the
        # leading dim so torch.cat stacks rows. Then L2-normalize; the epsilon
        # guards against division by a zero norm.
        v_emb = torch.cat([x['captions'][-1:] for x in outputs])
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        lang_encoder = model.model.sem_seg_head.predictor.lang_encoder
        # After this call the embeddings are read back from the encoder
        # attribute `caption_text_embeddings` (derived from name='caption').
        lang_encoder.get_text_embeddings(labels, is_eval=False, name='caption', prompt=False)
        t_emb = lang_encoder.caption_text_embeddings
        temperature = lang_encoder.logit_scale

        logits = vl_similarity(v_emb, t_emb, temperature)
        # Row 0: the single image in this batch.
        probs = logits.softmax(-1)[0]
        k = max(0, min(top_k, len(labels)))
        topk_prob, topk_idx = probs.topk(k)
        out_str = ''.join(
            f'{labels[idx.item()]}:{prob.item():.2f}; '
            for prob, idx in zip(topk_prob, topk_idx)
        )

    torch.cuda.empty_cache()
    return None, out_str, None