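"""LDM-based inpainting tool for removing a masked object from an image."""
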
import os

import numpy as np
import torch
import wget
from omegaconf import OmegaConf
from PIL import Image

from .inpainting_src.ldm_inpainting.ldm.models.diffusion.ddim import DDIMSampler
from .inpainting_src.ldm_inpainting.ldm.util import instantiate_from_config
from .utils import cal_dilate_factor, dilate_mask, gen_new_name, prompts


def make_batch(image, mask, device):
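    """Build a model-ready batch from an HWC uint8 image and an HW uint8 mask.

    Returns a dict with "image" (1, 3, H, W), "mask" (1, 1, H, W, binarized
    at 0.5), and "masked_image" tensors on `device`, all scaled to [-1, 1].
    """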
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)

    # Binarize the mask and add batch/channel dims -> (1, 1, H, W).
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    masked_image = (1 - mask) * image

    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device=device)
        batch[k] = batch[k] * 2.0 - 1.0
    return batch


class LDMInpainting:
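    """Object removal via latent diffusion (LDM) inpainting.

    Loads the LDM inpainting checkpoint (downloaded on first use) and fills
    the masked region of an image using DDIM sampling.
    """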
    def __init__(self, device):
        self.model_checkpoint_path = 'model_zoo/ldm_inpainting_big.ckpt'
        config_path = './iGPT/models/inpainting_src/ldm_inpainting/config.yaml'
        self.ddim_steps = 50
        self.device = device
        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config.model)
        self.download_parameters()
        # Load on CPU first, then move the whole model to the target device.
        model.load_state_dict(
            torch.load(self.model_checkpoint_path, map_location='cpu')["state_dict"],
            strict=False)
        self.model = model.to(device=device)
        self.sampler = DDIMSampler(model)

    def download_parameters(self):
        url = 'https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1'
        if not os.path.exists(self.model_checkpoint_path):
            # Ensure the checkpoint directory exists before downloading.
            os.makedirs(os.path.dirname(self.model_checkpoint_path), exist_ok=True)
            wget.download(url, out=self.model_checkpoint_path)

    @prompts(name="Remove the Masked Object",
             description="useful when you want to remove an object by masking the region in the image. "
                         "like: remove masked object or inpaint the masked region.. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and mask_path")
    @torch.no_grad()
    def inference(self, inputs):
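        """Inpaint the masked region of an image.

        `inputs` is a comma separated string "<image_path>,<mask_path>".
        Returns the path of the saved, inpainted image.
        """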
        print(f'inputs: {inputs}')
        img_path, mask_path = (s.strip() for s in inputs.split(',')[:2])
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        w, h = image.size
        # The model operates at a fixed 512x512 resolution; resize back afterwards.
        image = image.resize((512, 512))
        mask = mask.resize((512, 512))
        image = np.array(image)
        mask = np.array(mask)
        # Dilate the mask so it fully covers object boundaries.
        dilate_factor = cal_dilate_factor(mask.astype(np.uint8))
        mask = dilate_mask(mask, dilate_factor)

        with self.model.ema_scope():
            batch = make_batch(image, mask, device=self.device)
            # encode masked image and concat downsampled mask
            c = self.model.cond_stage_model.encode(batch["masked_image"])
            cc = torch.nn.functional.interpolate(batch["mask"],
                                                 size=c.shape[-2:])
            c = torch.cat((c, cc), dim=1)

            # The latent to sample has one channel fewer than the conditioning
            # `c`, whose last channel is the downsampled mask.
            shape = (c.shape[1] - 1,) + c.shape[2:]
            samples_ddim, _ = self.sampler.sample(S=self.ddim_steps,
                                                  conditioning=c,
                                                  batch_size=c.shape[0],
                                                  shape=shape,
                                                  verbose=False)
            x_samples_ddim = self.model.decode_first_stage(samples_ddim)

            image = torch.clamp((batch["image"] + 1.0) / 2.0,
                                min=0.0, max=1.0)
            mask = torch.clamp((batch["mask"] + 1.0) / 2.0,
                               min=0.0, max=1.0)
            predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                          min=0.0, max=1.0)

            # Composite: original pixels outside the mask, model prediction inside.
            inpainted = (1 - mask) * image + mask * predicted_image
            inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255

        inpainted = inpainted.astype(np.uint8)
        new_img_name = gen_new_name(img_path, 'LDMInpainter')
        new_img = Image.fromarray(inpainted)
        # Restore the original resolution before saving.
        new_img = new_img.resize((w, h))
        new_img.save(new_img_name)
        print(
            f"\nProcessed LDMInpainting, Inputs: {inputs}, "
            f"Output Image: {new_img_name}")
        return new_img_name

if __name__ == '__main__':
    # Smoke test; the image and mask paths below are illustrative.
    painting = LDMInpainting('cuda:0')
    res = painting.inference('image/82e612_fe54ca_raw.png,image/04a785_fe54ca_mask.png')
    print(res)