# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from utils.visualizer import Visualizer
from detectron2.utils.colormap import random_color
from detectron2.data import MetadataCatalog

# Preprocessing: resize (bicubic) so the shorter image side becomes 512 px.
t = [transforms.Resize(512, interpolation=Image.BICUBIC)]
transform = transforms.Compose(t)
metadata = MetadataCatalog.get('ade20k_panoptic_train')


def _normalize_prompts(texts):
    """Split a comma-separated prompt string into grounding queries.

    Each segment is stripped of surrounding whitespace and terminated with a
    '.' (unless it already ends with one). Empty segments, e.g. from a
    trailing comma or blank input, are dropped.

    Returns:
        A list of one-element lists, e.g. ``"dog, cat"`` ->
        ``[['dog.'], ['cat.']]``.
    """
    prompts = []
    for text in texts.split(','):
        # Strip BEFORE the '.' check so "dog. " and " cat" normalize cleanly.
        text = text.strip()
        if not text:
            continue
        prompts.append([text if text.endswith('.') else text + '.'])
    return prompts


def referring_segmentation(model, image, texts, inpainting_text, *args, **kwargs):
    """Segment the regions of ``image`` referred to by comma-separated ``texts``.

    Args:
        model: wrapper whose ``.model`` provides ``evaluate_grounding``; its
            ``.model.metadata`` attribute is overwritten with the ADE20K
            panoptic metadata as a side effect.
        image: input image, presumably a PIL RGB image (``transform`` is
            applied to it and ``.size`` is read as (width, height)) -- confirm
            with callers.
        texts: comma-separated referring expressions, e.g. ``"dog, red car"``.
        inpainting_text: unused by this task.

    Returns:
        ``(PIL.Image, '', None)``. The image is the resized input with one
        mask per expression drawn on it, each labeled with its expression. If
        ``texts`` contains no non-empty expression, the resized input is
        returned without querying the model.
    """
    model.model.metadata = metadata
    prompts = _normalize_prompts(texts)
    image_ori = transform(image)
    if not prompts:
        # Nothing to ground; avoid running the model on an empty/'.' query.
        return image_ori, '', None

    with torch.no_grad():
        width, height = image_ori.size
        image_ori_np = np.asarray(image_ori)
        # HWC array -> CHW tensor on GPU. copy() because np.asarray on a PIL
        # image may return a read-only array.
        images = torch.from_numpy(image_ori_np.copy()).permute(2, 0, 1).cuda()

        batch_inputs = [{'image': images, 'height': height, 'width': width,
                         'groundings': {'texts': prompts}}]
        outputs = model.model.evaluate_grounding(batch_inputs, None)
        visual = Visualizer(image_ori_np, metadata=metadata)

        grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
        demo = None
        # zip pairs each mask with its prompt and never indexes past either.
        for mask, (label,) in zip(grd_mask, prompts):
            # NOTE(review): random_color(maximum=1) presumably yields floats in
            # [0, 1]; astype(np.int32) truncates most channels to 0. Kept
            # unchanged -- confirm whether this is intended.
            color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
            # Pass the prompt string itself (not its one-element list) so the
            # label renders as "dog." rather than "['dog.']".
            demo = visual.draw_binary_mask(mask, color=color, text=label)
        # If the model produced no masks, fall back to the undrawn image.
        res = demo.get_image() if demo is not None else image_ori_np

    torch.cuda.empty_cache()
    return Image.fromarray(res), '', None