import os

import numpy as np
import torch
import wget
from omegaconf import OmegaConf
from PIL import Image

from .inpainting_src.ldm_inpainting.ldm.models.diffusion.ddim import DDIMSampler
from .inpainting_src.ldm_inpainting.ldm.util import instantiate_from_config
from .utils import cal_dilate_factor, dilate_mask, gen_new_name, prompts


def make_batch(image, mask, device):
    # Normalize the uint8 HWC image to [0, 1] and convert to NCHW.
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)

    # Binarize the mask: 1 marks the region to inpaint.
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Zero out the region to be inpainted.
    masked_image = (1 - mask) * image

    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device=device)
        # Rescale from [0, 1] to [-1, 1], the range the LDM expects.
        batch[k] = batch[k] * 2.0 - 1.0
    return batch


class LDMInpainting:
    def __init__(self, device):
        self.model_checkpoint_path = 'model_zoo/ldm_inpainting_big.ckpt'
        config = './iGPT/models/inpainting_src/ldm_inpainting/config.yaml'
        self.ddim_steps = 50
        self.device = device
        config = OmegaConf.load(config)
        model = instantiate_from_config(config.model)
        self.download_parameters()
        model.load_state_dict(torch.load(self.model_checkpoint_path)["state_dict"],
                              strict=False)
        self.model = model.to(device=device)
        self.sampler = DDIMSampler(model)

    def download_parameters(self):
        url = 'https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1'
        if not os.path.exists(self.model_checkpoint_path):
            wget.download(url, out=self.model_checkpoint_path)

    @prompts(name="Remove the Masked Object",
             description="useful when you want to remove an object by masking the region in the image, "
                         "like: remove masked object or inpaint the masked region. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and mask_path")
    @torch.no_grad()
    def inference(self, inputs):
        print(f'inputs: {inputs}')
        img_path, mask_path = inputs.split(',')[0], inputs.split(',')[1]
        img_path = img_path.strip()
        mask_path = mask_path.strip()
        image = Image.open(img_path)
        mask = Image.open(mask_path).convert('L')
        w, h = image.size
        image = image.resize((512, 512))
        mask = mask.resize((512, 512))
        image = np.array(image)
        mask = np.array(mask)
        # Dilate the mask so the inpainted region fully covers object edges.
        dilate_factor = cal_dilate_factor(mask.astype(np.uint8))
        mask = dilate_mask(mask, dilate_factor)

        with self.model.ema_scope():
            batch = make_batch(image, mask, device=self.device)
            # Encode the masked image and concat the downsampled mask
            # as the conditioning for the diffusion model.
            c = self.model.cond_stage_model.encode(batch["masked_image"])
            cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
            c = torch.cat((c, cc), dim=1)

            # Sample in latent space; the shape drops the concatenated mask channel.
            shape = (c.shape[1] - 1,) + c.shape[2:]
            samples_ddim, _ = self.sampler.sample(S=self.ddim_steps,
                                                  conditioning=c,
                                                  batch_size=c.shape[0],
                                                  shape=shape,
                                                  verbose=False)
            x_samples_ddim = self.model.decode_first_stage(samples_ddim)

            # Map everything back from [-1, 1] to [0, 1].
            image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
            mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
            predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            # Composite: keep original pixels outside the mask,
            # use the model's prediction inside it.
            inpainted = (1 - mask) * image + mask * predicted_image
            inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255

        inpainted = inpainted.astype(np.uint8)
        new_img_name = gen_new_name(img_path, 'LDMInpainter')
        new_img = Image.fromarray(inpainted)
        # Restore the original resolution before saving.
        new_img = new_img.resize((w, h))
        new_img.save(new_img_name)
        print(
            f"\nProcessed LDMInpainting, Inputs: {inputs}, "
            f"Output Image: {new_img_name}")
        return new_img_name


'''
if __name__ == '__main__':
    painting = LDMInpainting('cuda:0')
    res = painting.inference('image/82e612_fe54ca_raw.png,image/04a785_fe54ca_mask.png')
    print(res)
'''