# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------

import os

import openai
import torch
import numpy as np
from scipy import ndimage
from PIL import Image

from utils.inpainting import pad_image, crop_image
from torchvision import transforms
from utils.visualizer import Visualizer
from diffusers import StableDiffusionInpaintPipeline

from detectron2.utils.colormap import random_color
from detectron2.data import MetadataCatalog

# Resize so the shorter image side becomes 512 px (bicubic) before cropping.
t = []
t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
transform = transforms.Compose(t)

metadata = MetadataCatalog.get('ade20k_panoptic_train')

# The inpainting pipeline is loaded once, at import time, in fp16 on CUDA.
# Alternative checkpoint: "stabilityai/stable-diffusion-2-inpainting".
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Few-shot examples prepended to every GPT-3 request. The grounding phrase is
# written inside [...]; the parser below reads the inpainting prompt from <...>.
# NOTE(review): none of these examples fills in a <...> target, so the model
# is never shown that syntax -- confirm the examples are complete.
prompts = []
prompts.append("remove the person, task: (referring editing), source: [person], target:;")
prompts.append("remove the person in the middle, task: (referring editing), source: [person in the middle], target:;")
prompts.append("remove the dog on the left side, task: (referring editing), source: [dog on the left side], target:;")
prompts.append("change the apple to a pear, task: (referring editing), source: [apple], target: ;")
prompts.append("change the red apple to a green one, task: (referring editing), source: [red apple], target: ;")
prompts.append("replace the dog with a cat, task: (referring editing), source: [dog], target: ;")
prompts.append("replace the red apple with a green one, task: (referring editing), source: [red apple], target: ;")


def get_gpt3_response(prompt):
    """Send *prompt* to the OpenAI Completions API and return the raw response.

    The API key is read from the ``OPENAI_API_KEY`` environment variable on
    every call.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        temperature=0.7,
        max_tokens=128,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return response


def _extract_delimited(text, opener, closer):
    """Return the text between the first *opener* and the next *closer*, or ''.

    An empty string is returned when either delimiter is missing, so a
    malformed completion never leaks the whole reply into a prompt.
    """
    start = text.find(opener)
    if start == -1:
        return ''
    end = text.find(closer, start + 1)
    if end == -1:
        return ''
    return text[start + 1:end]


def referring_inpainting_gpt3(model, image, instruction, *args, **kwargs):
    """Ground the object named by *instruction* and optionally inpaint it.

    GPT-3 rewrites the free-form *instruction* into a grounding phrase
    (``[...]``) and an inpainting prompt (``<...>``). The phrase is grounded
    with ``model.model.evaluate_grounding``. If an inpainting prompt was
    produced (and is not "no"), the dilated mask is inpainted with Stable
    Diffusion.

    Args:
        model: wrapper whose ``.model`` exposes ``metadata`` and
            ``evaluate_grounding``.
        image: input image, presumably a PIL image (it goes through
            ``transform`` and ``crop_image``).
        instruction: natural-language editing request.
        *args, **kwargs: accepted for interface compatibility; unused.

    Returns:
        With inpainting: ``(mask_visualization, resp_text, inpainted_image)``.
        Without inpainting: ``(cropped_input, resp_text, mask_visualization)``.
        Image entries are PIL images; ``resp_text`` is the raw GPT-3 completion.
    """
    print(instruction)
    resp = get_gpt3_response(' '.join(prompts) + instruction + ',')
    resp_text = resp['choices'][0]['text']
    print(resp_text)
    ref_text = _extract_delimited(resp_text, '[', ']')
    inp_text = _extract_delimited(resp_text, '<', '>').strip()

    model.model.metadata = metadata
    # Grounding takes one list of phrases per image; phrases are '.'-terminated.
    phrase = ref_text if ref_text.strip().endswith('.') else (ref_text.strip() + '.')
    texts = [[phrase]]
    image_ori = crop_image(transform(image))

    with torch.no_grad():
        width, height = image_ori.size
        image_ori_np = np.asarray(image_ori)
        # Assumes an (H, W, 3) image; permute to (3, H, W) for the model.
        images = torch.from_numpy(image_ori_np.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts}}]
        outputs = model.model.evaluate_grounding(batch_inputs, None)

        visual = Visualizer(image_ori_np, metadata=metadata)
        grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
        labels = texts[0]
        for idx, mask in enumerate(grd_mask):
            # random_color(maximum=1) returns floats in [0, 1]; keep them as
            # floats -- an integer cast truncates nearly every channel to 0.
            color = random_color(rgb=True, maximum=1).tolist()
            label = labels[idx] if idx < len(labels) else None
            demo = visual.draw_binary_mask(mask, color=color, text=label)
        res = demo.get_image()

        if inp_text not in ('no', ''):
            # Grow the mask by 3 iterations of an 8-connected 3x3 dilation so
            # the inpainted region also covers the object's border.
            struct2 = ndimage.generate_binary_structure(2, 2)
            mask_dilated = ndimage.binary_dilation(
                grd_mask[0], structure=struct2, iterations=3
            ).astype(grd_mask[0].dtype)
            mask_image = Image.fromarray(mask_dilated * 255).convert('RGB')
            images_inpainting = pipe(
                prompt=inp_text,
                image=image_ori,
                mask_image=mask_image,
                height=height,
                width=width,
            ).images
            torch.cuda.empty_cache()
            return Image.fromarray(res), resp_text, images_inpainting[0]

        torch.cuda.empty_cache()
        return image_ori, resp_text, Image.fromarray(res)