# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com), Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import torch
import numpy as np
from PIL import Image
from utils.inpainting import pad_image
from torchvision import transforms
from utils.visualizer import Visualizer
from diffusers import StableDiffusionInpaintPipeline
from detectron2.utils.colormap import random_color
from detectron2.data import MetadataCatalog
from scipy import ndimage


# Resize so the shorter image side becomes 512 px (aspect ratio preserved).
t = [transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC)]
transform = transforms.Compose(t)
metadata = MetadataCatalog.get('ade20k_panoptic_train')

# NOTE: the inpainting pipeline is created at import time and moved to CUDA,
# so importing this module requires a GPU.
# Alternative checkpoint: "stabilityai/stable-diffusion-2-inpainting".
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")


def crop_image(input_image, multiple=64):
    """Crop *input_image* to its top-left region with sides divisible by *multiple*.

    Pixels past the largest multiple are dropped from the right and bottom
    edges, so the crop stays aligned with the image's (0, 0) corner.

    Args:
        input_image: PIL image to crop.
        multiple: Side-length granularity (positive int, default 64).

    Returns:
        A new PIL image of size
        ``(width // multiple * multiple, height // multiple * multiple)``.
    """
    width, height = input_image.size
    crop_w = width // multiple * multiple
    crop_h = height // multiple * multiple
    return input_image.crop((0, 0, crop_w, crop_h))


def _ground(model, image_ori, label):
    """Ground *label* in RGB *image_ori*; return (binary masks, visualization array).

    Raises:
        ValueError: if the model returns no grounding mask at all.
    """
    width, height = image_ori.size
    image_np = np.asarray(image_ori)
    with torch.no_grad():
        # np.asarray() of a PIL image may be read-only; copy before handing to torch.
        images = torch.from_numpy(image_np.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{
            'image': images,
            'height': height,
            'width': width,
            # The grounding head takes one list of phrases per image.
            'groundings': {'texts': [[label]]},
        }]
        outputs = model.model.evaluate_grounding(batch_inputs, None)
        grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()

        visual = Visualizer(image_np, metadata=metadata)
        demo = None
        for mask in grd_mask:
            # NOTE(review): random_color(maximum=1) presumably yields floats in
            # [0, 1]; astype(np.int32) truncates them to 0/1 -- confirm intended.
            color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
            # Only one phrase is grounded, so every mask belongs to `label`
            # (passing the phrase list itself would render its repr).
            demo = visual.draw_binary_mask(mask, color=color, text=label)
        if demo is None:
            raise ValueError(f"grounding produced no mask for {label!r}")
        res = demo.get_image()
    return grd_mask, res


def _dilated_mask_image(mask, iterations=3):
    """Dilate a 2-D binary *mask* (8-connected) and return it as a black/white RGB image."""
    structure = ndimage.generate_binary_structure(2, 2)  # 3x3 all-ones kernel
    dilated = ndimage.binary_dilation(mask, structure=structure, iterations=iterations)
    # Build an explicit uint8 (0/255) 'L' image rather than a float 'F' image.
    return Image.fromarray(dilated.astype(np.uint8) * 255).convert('RGB')


def _inpaint(image_ori, mask, prompt):
    """Inpaint the dilated *mask* region of *image_ori* in place, guided by *prompt*."""
    image_crop = crop_image(image_ori)
    # assumes `mask` has the same (H, W) as image_ori -- TODO confirm
    mask_crop = crop_image(_dilated_mask_image(mask))
    width, height = image_crop.size
    inpainted = pipe(
        prompt=prompt,
        image=image_crop,
        mask_image=mask_crop,
        height=height,
        width=width,
    ).images[0]
    # Both crops keep the (0, 0) corner and paste() without a box targets the
    # upper-left corner, so the result lands back in place; the strips removed
    # by crop_image keep their original pixels.
    image_ori.paste(inpainted)


def referring_inpainting(model, image, texts, inpainting_text, *args, **kwargs):
    """Ground a referring expression in *image* and optionally inpaint the match.

    Args:
        model: Wrapper whose ``.model`` provides ``evaluate_grounding``; its
            ``.model.metadata`` is overwritten with the module's ADE20K metadata.
        image: Input PIL image; resized so its shorter side is 512 px.
        texts: Referring expression (a string); a trailing '.' is appended if missing.
        inpainting_text: Prompt for Stable Diffusion inpainting. ``'no'``, an
            empty/whitespace-only string, or ``None`` disables inpainting.
        *args, **kwargs: Accepted for interface compatibility; unused.

    Returns:
        With inpainting:    ``(grounding_visualization, '', inpainted_image)``.
        Without inpainting: ``(resized_image, 'text', grounding_visualization)``.

    Raises:
        ValueError: if grounding returns no mask.
    """
    model.model.metadata = metadata
    label = texts.strip()
    if not label.endswith('.'):
        label += '.'

    # convert('RGB') always returns a new image: the model gets 3 channels even
    # for grayscale/RGBA input, and the in-place paste() in _inpaint cannot
    # mutate the caller's image (transform() may hand back the same object).
    image_ori = transform(image).convert('RGB')
    prompt = (inpainting_text or '').strip()

    try:
        grd_mask, res = _ground(model, image_ori, label)
        if prompt in ('no', ''):
            return image_ori, 'text', Image.fromarray(res)
        _inpaint(image_ori, grd_mask[0], prompt)
        return Image.fromarray(res), '', image_ori
    finally:
        # Release cached, unused CUDA memory after every call, including on error.
        torch.cuda.empty_cache()