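"""DitailDemo: two-stage image-to-image generation via DDIM inversion.

Stage 1 (inversion): encode the source image with the inversion model's VAE
and run DDIM inversion, storing the latent at every timestep.
Stage 2 (sampling): denoise with the sampling model (optionally fused with a
LoRA), injecting attention/conv features from the stored latents (PnP-style)
and applying classifier-free guidance.
"""
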
import os
import yaml
import argparse
import warnings
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T
from transformers import logging
from diffusers import DDIMScheduler, StableDiffusionPipeline

from .ditail_utils import *

# suppress warnings
logging.set_verbosity_error()
warnings.filterwarnings("ignore", message=".*LoRA backend.*")

class DitailDemo(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
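        # flatten args (dict or argparse.Namespace) into instance attributes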
        if isinstance(self.args, dict):
            for k, v in args.items():
                setattr(self, k, v)
        else:
            for k, v in vars(args).items():
                setattr(self, k, v)

    def load_inv_model(self):
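        """Load the inversion model (stage 1) and cache its pipeline components."""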
        self.scheduler = DDIMScheduler.from_pretrained(self.inv_model, subfolder='scheduler')
        self.scheduler.set_timesteps(self.inv_steps, device=self.device)
        print(f'[INFO] Loading inversion model: {self.inv_model}')
        pipe = StableDiffusionPipeline.from_pretrained(
            self.inv_model, torch_dtype=torch.float16
        ).to(self.device)
        pipe.enable_xformers_memory_efficient_attention()
        self.text_encoder = pipe.text_encoder
        self.tokenizer = pipe.tokenizer
        self.unet = pipe.unet
        self.vae = pipe.vae
        self.tokenizer_kwargs = dict(
            truncation=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.tokenizer.model_max_length
        )
    
    def load_spl_model(self):
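        """Load the sampling model (stage 2), fusing LoRA weights if requested.

        If no LoRA is used and the sampling model matches the inversion
        model, the stage-1 components are reused and only the scheduler
        is re-initialized.
        """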
        self.scheduler = DDIMScheduler.from_pretrained(self.spl_model, subfolder='scheduler')
        self.scheduler.set_timesteps(self.spl_steps, device=self.device)
        print(f'[INFO] Loading sampling model: {self.spl_model}')
        if (self.lora != 'none') or (self.inv_model != self.spl_model):
            pipe = StableDiffusionPipeline.from_pretrained(
                self.spl_model, torch_dtype=torch.float16
            ).to(self.device)
            if self.lora != 'none':
                # pipe.unfuse_lora()
                # pipe.unload_lora_weights()
                pipe.load_lora_weights(self.lora_dir, weight_name=f'{self.lora}.safetensors')
                pipe.fuse_lora(lora_scale=self.lora_scale)
            pipe.enable_xformers_memory_efficient_attention()
            self.text_encoder = pipe.text_encoder
            self.tokenizer = pipe.tokenizer
            self.unet = pipe.unet
            self.vae = pipe.vae
            self.tokenizer_kwargs = dict(
                truncation=True,
                return_tensors='pt',
                padding='max_length',
                max_length=self.tokenizer.model_max_length
            )
    
    @torch.no_grad()
    def encode_image(self, image_pil):
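        """Encode a PIL image into an SD latent (scaled by 0.18215)."""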
        image_pil = T.Resize(512)(image_pil)
        width, height = image_pil.size
        # crop so both sides are divisible by 8, as the VAE downsamples by 8
        image_pil = image_pil.crop((0, 0, width - width % 8, height - height % 8))

        image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
        with torch.autocast(device_type=self.device, dtype=torch.float32):
            image = 2 * image - 1
            posterior = self.vae.encode(image).latent_dist
            latent = posterior.mean * 0.18215
        return latent

    @torch.no_grad()
    def invert_image(self, cond, latent):
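        """Run DDIM inversion, mapping the clean latent back to noise.

        Iterates the timesteps in reverse and applies the DDIM update solved
        for x_t given x_{t-1}:
            pred_x0 = (x_{t-1} - sqrt(1 - a_{t-1}) * eps) / sqrt(a_{t-1})
            x_t     = sqrt(a_t) * pred_x0 + sqrt(1 - a_t) * eps
        The latent at every timestep is stored in self.latents for feature
        injection during sampling.
        """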
        self.latents = {}
        timesteps = reversed(self.scheduler.timesteps)
        with torch.autocast(device_type=self.device, dtype=torch.float32):
            for i, t in enumerate(tqdm(timesteps)):
                cond_batch = cond.repeat(latent.shape[0], 1, 1)
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i-1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )
                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
                eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample
                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
                self.latents[t.item()] = latent
        self.noisy_latent = latent

    @torch.no_grad()
    def extract_latents(self):
        # get the embeddings for pos & neg prompts
        # self.pos_prompt = ' ,'.join(LORA_TRIGGER_WORD.get(self.lora, [''])+[self.pos_prompt])
        # print('the prompt after adding trigger word:', self.pos_prompt)
        text_pos = self.tokenizer(self.pos_prompt, **self.tokenizer_kwargs)
        text_neg = self.tokenizer(self.neg_prompt, **self.tokenizer_kwargs)
        self.emb_pos = self.text_encoder(text_pos.input_ids.to(self.device))[0]
        self.emb_neg = self.text_encoder(text_neg.input_ids.to(self.device))[0]
        # apply condition scaling
        cond = self.alpha * self.emb_pos - self.beta * self.emb_neg
        # encode source image & apply DDIM inversion
        self.invert_image(cond, self.encode_image(self.img))

    @torch.no_grad()
    def latent_to_image(self, latent, save_path=None):
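        """Decode a latent with the VAE into a PIL image, optionally saving it."""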
        with torch.autocast(device_type=self.device, dtype=torch.float32):
            latent = 1 / 0.18215 * latent
            image = self.vae.decode(latent).sample[0]
            image = (image / 2 + 0.5).clamp(0, 1)
        image = T.ToPILImage()(image)
        if save_path is not None:
            image.save(save_path)
        return image

    def init_injection(self, attn_ratio=0.5, conv_ratio=0.8):
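        """Register PnP feature injection over the early denoising steps.

        Attention features are injected for the first attn_ratio fraction of
        sampling timesteps and conv features for the first conv_ratio
        fraction (timesteps run from noisiest to cleanest).
        """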
        attn_thresh = int(attn_ratio * self.spl_steps)
        conv_thresh = int(conv_ratio * self.spl_steps)
        self.attn_inj_timesteps = self.scheduler.timesteps[:attn_thresh]
        self.conv_inj_timesteps = self.scheduler.timesteps[:conv_thresh]
        register_attn_inj(self, self.attn_inj_timesteps)
        register_conv_inj(self, self.conv_inj_timesteps)

    @torch.no_grad()
    def sampling_loop(self):
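        """Denoise from the inverted latent with feature injection and CFG.

        Each U-Net call takes a batch of three latents, [source, x, x],
        paired with [empty, positive, negative] prompt embeddings: the source
        branch supplies the injected features, while the positive/negative
        branches are combined via classifier-free guidance with scale omega.
        """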
        # init text embeddings
        text_ept = self.tokenizer('', **self.tokenizer_kwargs)
        self.emb_ept = self.text_encoder(text_ept.input_ids.to(self.device))[0]
        self.emb_spl = torch.cat([self.emb_ept, self.emb_pos, self.emb_neg], dim=0)
        with torch.autocast(device_type=self.device, dtype=torch.float16):
            # use noisy latent as starting point
            x = self.latents[self.scheduler.timesteps[0].item()]
            # sampling loop
            for t in tqdm(self.scheduler.timesteps):
                # concat latents & register timestep
                src_latent = self.latents[t.item()]
                latents = torch.cat([src_latent, x, x])
                register_time(self, t.item())
                # apply U-Net for denoising
                noise_pred = self.unet(latents, t, encoder_hidden_states=self.emb_spl).sample
                # classifier-free guidance
                _, noise_pred_pos, noise_pred_neg = noise_pred.chunk(3)
                noise_pred = noise_pred_neg + self.omega * (noise_pred_pos - noise_pred_neg)
                # denoise step
                x = self.scheduler.step(noise_pred, t, x).prev_sample
            # save output latent
            self.output_latent = x

    def run_ditail(self):
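        """Run both stages end-to-end and return the output PIL image."""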
        # init output dir & dump config
        os.makedirs(self.output_dir, exist_ok=True)
        # self.save_dir = get_save_dir(self.output_dir)
        # os.makedirs(self.save_dir, exist_ok=True)
        # with open(os.path.join(self.output_dir, 'config.yaml'), 'w') as f:
        #     if isinstance(self.args, dict):
        #         f.write(yaml.dump(self.args))
        #     else:
        #         f.write(yaml.dump(vars(self.args)))
        # step 1: inversion stage
        self.load_inv_model()
        self.extract_latents()
        # self.latent_to_image(
        #     latent=self.noisy_latent,
        #     save_path=os.path.join(self.save_dir, 'noise.png')
        # )
        # step 2: sampling stage
        self.load_spl_model()
        if not self.no_injection:
            self.init_injection()
        self.sampling_loop()
        return self.latent_to_image(
            latent=self.output_latent,
            # save_path=os.path.join(self.save_dir, 'output.png')
        )

def main(args):
    seed_everything(args.seed)
    # load the source image here so DitailDemo picks it up as self.img
    args.img = Image.open(args.img_path).convert('RGB')
    ditail = DitailDemo(args)
    output = ditail.run_ditail()
    output.save(os.path.join(args.output_dir, 'output.png'))

if __name__ == "__main__":
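    # Example invocation (run as a module so the relative import resolves;
    # package name and paths are illustrative):
    #   python -m ditail.ditail_demo --img_path ./examples/source.png \
    #       --pos_prompt "a watercolor painting of a castle"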
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str, default='./output_demo')
    parser.add_argument('--inv_model', type=str, default='runwayml/stable-diffusion-v1-5',
                        help='Pre-trained inversion model name or path (step 1)')
    parser.add_argument('--spl_model', type=str, default='runwayml/stable-diffusion-v1-5',
                        help='Pre-trained sampling model name or path (step 2)')
    parser.add_argument('--inv_steps', type=int, default=50,
                        help='Number of inversion steps (step 1)')
    parser.add_argument('--spl_steps', type=int, default=50,
                        help='Number of sampling steps (step 2)')
    parser.add_argument('--img_path', type=str, required=True,
                        help='Path to the source image')
    parser.add_argument('--pos_prompt', type=str, required=True,
                        help='Positive prompt (used for inversion and guidance)')
    parser.add_argument('--neg_prompt', type=str, default='worst quality, blurry, low res, NSFW',
                        help='Negative prompt (used for inversion and guidance)')
    parser.add_argument('--alpha', type=float, default=2.0,
                        help='Positive prompt scaling factor')
    parser.add_argument('--beta', type=float, default=1.0,
                        help='Negative prompt scaling factor')
    parser.add_argument('--omega', type=float, default=15,
                        help='Classifier-free guidance factor')
    parser.add_argument('--mask', type=str, default='none',
                        help='Optional mask for regional injection')
    parser.add_argument('--lora', type=str, default='none',
                        help='Optional LoRA for the sampling stage')
    parser.add_argument('--lora_dir', type=str, default='./lora',
                        help='Directory containing LoRA weight files (.safetensors)')
    parser.add_argument('--lora_scale', type=float, default=0.7,
                        help='Optional LoRA scaling weight')
    parser.add_argument('--no_injection', action="store_true",
                        help='Do not use PnP injection')
    args = parser.parse_args()
    main(args)