File size: 13,405 Bytes
73c6f92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import torch
from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
import tqdm


def get_subject_idx(model,prompt,src_subject,device):
    """Find the token positions of each subject's words inside its tokenized prompt.

    Args:
        model: object exposing a Hugging Face-style ``tokenizer`` (called with
            ``padding``/``max_length``/``truncation``/``return_tensors`` and
            providing ``encode`` and ``model_max_length``).
        prompt: list of prompt strings, one per batch item.
        src_subject: list of subject strings aligned with ``prompt``.
        device: device the tokenized ids are moved to.

    Returns:
        A list with one inner list per (prompt, subject) pair, holding the
        positions in the padded prompt whose token id equals the first
        sub-token of any subject word. Returns ``[[]]`` when no pairs are given,
        matching the previous single-item behaviour.
    """
    tokenizer = model.tokenizer
    tokenized_prompt = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    input_ids = tokenized_prompt['input_ids']

    all_subject_idxs = []
    for subject, input_id in zip(src_subject, input_ids):
        # Only the first sub-token of each subject word is matched.
        # ``split()`` (not ``split(' ')``) avoids empty words on repeated spaces,
        # which previously made ``encode('')[0]`` raise IndexError.
        subject_token_ids = set()
        for word in subject.split():
            word_ids = tokenizer.encode(word, add_special_tokens=False)
            if word_ids:
                subject_token_ids.add(word_ids[0])
        # Keep one index list per prompt instead of overwriting with the last one.
        all_subject_idxs.append(
            [i for i, token_id in enumerate(input_id.tolist()) if token_id in subject_token_ids]
        )
    return all_subject_idxs or [[]]


def add_function(model):
    """Attach a ``generate_with_adapters`` sampling loop to ``model``.

    The function is stored as a plain attribute, not a bound method. Callers
    therefore pass the pipeline explicitly as the first argument, e.g.
    ``pipe.generate_with_adapters(pipe, prompt_embeds, 50, generator)``.

    Args:
        model: object that receives the ``generate_with_adapters`` attribute
            (the ``StableDiffusionPipeline`` in this file).
    """

    @torch.no_grad()
    def generate_with_adapters(
        model,
        prompt_embeds,
        num_inference_steps,
        generator,
        t_range=range(0, 950),
        guidance_scale=7.5,
        height=512,
        width=512,
    ):
        """Run classifier-free-guided denoising with the MoMA attention processors.

        Args:
            model: the pipeline. It must provide ``prepare_latents``, ``scheduler``,
                ``unet``, ``vae`` and a ``moMA_generator`` back-reference.
            prompt_embeds: unconditional and conditional embeddings concatenated
                along dim 0, unconditional half first.
            num_inference_steps: number of scheduler steps.
            generator: optional ``torch.Generator`` for the initial noise.
            t_range: timesteps at which all adapters (self + cross) are enabled.
                Outside it only the cross-attention adapters are enabled.
                Compared against ``int(t)``.
            guidance_scale: classifier-free guidance weight.
            height, width: output image size in pixels.

        Returns:
            ``(image, mask)``: the VAE-decoded image batch, and the last
            cross-attention mask repeated to 3 channels (``None`` if no step ran).
        """
        batch_size = prompt_embeds.shape[0] // 2
        latents = model.prepare_latents(
            batch_size,
            model.unet.config.in_channels,
            height,
            width,
            prompt_embeds.dtype,
            prompt_embeds.device,
            generator,
        )

        model.scheduler.set_timesteps(num_inference_steps)

        mask_ig_prev = None
        for t in tqdm.tqdm(model.scheduler.timesteps):
            # Convert to a Python int: the membership test is then exact and
            # O(1) for ranges, instead of comparing a tensor to every element.
            if int(t) in t_range:
                model.moMA_generator.toggle_enable_flag('all')
            else:
                model.moMA_generator.toggle_enable_flag('cross')

            latent_model_input = torch.cat([latents] * 2)
            # Identity for DDIM; keeps the loop correct for other diffusers schedulers.
            latent_model_input = model.scheduler.scale_model_input(latent_model_input, t)
            noise_pred = model.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                return_dict=False,
            )[0]

            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = model.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # NOTE(review): slicing off the first ``batch_size`` entries presumably
            # keeps the conditional half of the CFG batch. Confirm against
            # get_mask_from_cross.
            mask_ig_prev = get_mask_from_cross(model.unet.attn_processors)[latents.shape[0]:]

            # The mask from this step conditions the next step's self- and cross-attention.
            model.moMA_generator.set_self_mask('self', 'ig', mask_ig_prev)
            model.moMA_generator.set_self_mask('cross', mask=mask_ig_prev.clone().detach())

        image = model.vae.decode(latents / model.vae.config.scaling_factor, return_dict=False)[0]
        # Assumes a single-channel mask (B, 1, H, W). TODO confirm.
        mask = mask_ig_prev.repeat(1, 3, 1, 1) if mask_ig_prev is not None else None
        return image, mask

    model.generate_with_adapters = generate_with_adapters


class ImageProjModel(torch.nn.Module):
    """Project a global image embedding into a short sequence of context tokens.

    A single linear layer maps each ``clip_embeddings_dim`` vector to
    ``clip_extra_context_tokens * cross_attention_dim`` features. These are
    reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalised.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        projected_width = clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised context tokens of shape ``(-1, n_tokens, cross_attention_dim)``.

        The leading ``-1`` absorbs every dimension before the projected features.
        """
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(token_shape)
        return self.norm(tokens)


class MoMA_generator:
    """Stable Diffusion generator wired with MoMA adapter attention processors.

    Loads a Realistic Vision SD pipeline (fp16) with the ``sd-vae-ft-mse`` VAE.
    Every UNet attention processor is replaced with ``IPAttnProcessor_Self``
    (``attn1`` self-attention layers) or ``IPAttnProcessor`` (all other layers).
    The class also exposes helpers that toggle and configure those processors
    during generation.
    """

    def __init__(self, device,args):
        """Load the diffusion pipeline and install the custom attention processors.

        Args:
            device: device the pipeline and adapter modules are moved to.
            args: stored as ``self.args``; not read elsewhere in this class.
        """
        self.args = args
        self.device = device

        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        print('Loading VAE: stabilityai--sd-vae-ft-mse...')
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

        print('Loading StableDiffusion: Realistic_Vision...')
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V4.0_noVAE",
            torch_dtype=torch.float16,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
        ).to(self.device)

        self.unet = self.pipe.unet
        # Attach ``generate_with_adapters`` to the pipeline and give it a
        # back-reference to this object so the loop can toggle processors.
        add_function(self.pipe)
        self.pipe.moMA_generator = self

        self.set_ip_adapter()
        self.image_proj_model = self.init_proj()

    def init_proj(self):
        """Build the fp16 image projection model (1024-d embedding -> 4 tokens of 768-d)."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=768,
            clip_embeddings_dim=1024,
            clip_extra_context_tokens=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        """Replace every UNet attention processor with a MoMA adapter processor.

        ``attn1`` (self-attention) layers get ``IPAttnProcessor_Self``. All other
        layers get ``IPAttnProcessor`` with the UNet's cross-attention dim. The
        hidden size comes from the down/mid/up block that owns the layer.

        Raises:
            ValueError: if a processor name is not under a down/mid/up block.
        """
        unet = self.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # Reads a single digit: assumes fewer than 10 up blocks.
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            else:
                # Otherwise hidden_size would be undefined, or silently reused
                # from the previous layer.
                raise ValueError(f"Unexpected attention processor name: {name}")
            if cross_attention_dim is None:
                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
            else:
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    @torch.inference_mode()
    def get_image_embeds_CFG(self, llava_emb):
        """Project ``llava_emb`` into conditional and unconditional (zero-input) prompt tokens.

        Returns:
            ``(image_prompt_embeds, uncond_image_prompt_embeds)``.
        """
        clip_image_embeds = llava_emb
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # The unconditional branch uses an all-zero embedding of the same shape.
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def get_image_crossAttn_feature(
            self,
            llava_emb,
            num_samples=1,
    ):
        """Return the projected image tokens, each repeated ``num_samples`` times per batch item.

        Returns:
            ``(cond, uncond)`` tensors of shape ``(bs * num_samples, seq_len, dim)``.
        """
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_CFG(llava_emb)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        return image_prompt_embeds, uncond_image_prompt_embeds

    # Features come from the UNet's self-attention layers: the reference image is fed through the UNet at t=0.
    def get_image_selfAttn_feature(
            self,
            pil_image,
            prompt,
    ):
        """Run the reference image through the UNet once so self-attention processors capture it.

        Steps: enable only the self-attention processors in ``'extract'`` mode,
        encode the image with the VAE, noise it at timestep 0, and run one UNet
        forward pass. The captured features stay in the attention processors.

        Args:
            pil_image: despite its name, an image tensor accepted by
                ``vae.encode`` (``generate_with_MoMA`` passes a half-precision
                tensor). It is moved to ``self.device`` here.
            prompt: text (string or list of strings) used as cross-attention context.

        Returns:
            None.
        """
        self.toggle_enable_flag('self')
        self.toggle_extract_inject_flag('self', 'extract')
        tokenizer = self.pipe.tokenizer
        tokenized_prompt = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        text_embeddings = self.pipe.text_encoder(input_ids=tokenized_prompt.input_ids)[0]

        # Tensor.to returns a new tensor; the result must be kept, not discarded.
        ref_image = pil_image.to(self.device)

        with torch.no_grad():
            latents = self.pipe.vae.encode(ref_image).latent_dist.sample()
        latents = latents * self.pipe.vae.config.scaling_factor

        noise = torch.randn_like(latents)
        timesteps = torch.tensor([0], device=latents.device).long()  # fixed to 0
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timesteps)

        # Output is discarded; the features are stored in the attn_processors.
        _ = self.unet(noisy_latents, timestep=timesteps, encoder_hidden_states=text_embeddings)["sample"]

        return None

    @torch.no_grad()
    def generate_with_MoMA(
        self,
        batch,
        llava_emb=None,
        seed=None,
        device='cuda',
    ):
        """Generate images for one batch, conditioned on a reference subject and a text prompt.

        Args:
            batch: dict with ``'image'`` (reference image tensor), ``'mask'``
                (subject mask tensor), ``'label'`` (subject names) and ``'text'``
                (prompts). Only the first label and the first text are used.
            llava_emb: image embedding fed to ``image_proj_model``. Required
                despite the ``None`` default.
            seed: optional seed for the initial noise generator.
            device: device the batch tensors are moved to.

        Returns:
            ``(images, mask)`` on CPU: images clipped to ``[0, 1]``, and the final
            cross-attention mask (``None`` if none was produced).

        Raises:
            ValueError: if ``llava_emb`` is None.
        """
        if llava_emb is None:
            raise ValueError("generate_with_MoMA requires llava_emb")

        self.reset_all()
        img_ig = batch['image'].half().to(device)
        mask_id = batch['mask'].half().to(device)
        subject = batch['label'][0]
        prompt = batch['text'][0]

        prompt = [f"photo of a {subject}. "+ prompt]
        subject_idx = get_subject_idx(self.pipe,prompt,[subject],self.device)
        negative_prompt = None

        # get context-cross-attention feature (from MLLM decoder)
        cond_llava_embeds, uncond_llava_embeds = self.get_image_crossAttn_feature(llava_emb,num_samples=1)
        # get subject self-attention feature (from Unet); stored in attn_processors
        self.get_image_selfAttn_feature(img_ig,subject)

        with torch.inference_mode():
            # _encode_prompt returns [negative; positive] concatenated along dim 0.
            prompt_embeds = self.pipe._encode_prompt(
                prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            # Append the projected image tokens after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds_, cond_llava_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_llava_embeds], dim=1)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

        self.set_self_mask('eraseAll')
        self.toggle_enable_flag('all')
        self.toggle_extract_inject_flag('all','masked_generation')
        self.set_self_mask('self','id',mask_id)
        self.set_cross_subject_idxs(subject_idx)

        images, mask = self.pipe.generate_with_adapters(
            self.pipe,
            prompt_embeds,
            50,
            generator,
        )
        # Map decoder output from [-1, 1] to [0, 1].
        images = torch.clip((images+1)/2.0,min=0.0,max=1.0)

        return images.cpu(), (mask.cpu() if mask is not None else None)

    def set_selfAttn_strength(self, strength):
        """Set the scale of every self-attention adapter to ``strength``; reset cross adapters to 1.0."""
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = 1.0
            if isinstance(attn_processor, IPAttnProcessor_Self):
                attn_processor.scale = strength

    def set_cross_subject_idxs(self, subject_idxs):
        """Store the subject token positions on every cross-attention adapter."""
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.subject_idxs = subject_idxs

    def set_self_mask(self,mode,id_ig='', mask=None):
        """Set or clear the masks stored on the attention processors.

        Args:
            mode: ``'eraseAll'`` clears the self masks (``mask_id``, ``mask_ig``)
                and cross masks (``mask_i``, ``mask_ig_prev``). ``'self'`` sets a
                self-attention mask chosen by ``id_ig``. ``'cross'`` sets
                ``mask_ig_prev`` on cross-attention adapters.
            id_ig: with ``mode='self'``, ``'id'`` sets ``mask_id`` and ``'ig'``
                sets ``mask_ig``.
            mask: the mask to store.
        """
        for attn_processor in self.unet.attn_processors.values():
            if mode == 'eraseAll':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    attn_processor.mask_id,attn_processor.mask_ig = None,None
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mask_i, attn_processor.mask_ig_prev = None, None
            if mode == 'self':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    if id_ig == 'id':attn_processor.mask_id = mask
                    if id_ig == 'ig':attn_processor.mask_ig = mask
            if mode == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mask_ig_prev = mask

    def toggle_enable_flag(self, processor_enable_mode):
        """Enable a subset of processors: ``'cross'``, ``'self'``, ``'all'`` or ``'none'``.

        ``'cross'`` and ``'self'`` enable that kind and disable the other kind.
        Unrecognised modes leave the processors unchanged.
        """
        for attn_processor in self.unet.attn_processors.values():
            if processor_enable_mode == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):attn_processor.enabled = True
                if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.enabled = False
            elif processor_enable_mode == 'self':
                if isinstance(attn_processor, IPAttnProcessor):attn_processor.enabled = False
                if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.enabled = True
            elif processor_enable_mode == 'all':
                attn_processor.enabled = True
            elif processor_enable_mode == 'none':
                attn_processor.enabled = False

    def toggle_extract_inject_flag(self, processor_name, mode):
        """Set ``mode`` on the processors selected by ``processor_name``.

        Args:
            processor_name: ``'cross'``, ``'self'`` or ``'all'``.
            mode: string stored on each processor. Modes used in this file are
                ``'extract'`` and ``'masked_generation'``; 'inject'/'both' are
                mentioned by the original comment.
        """
        for attn_processor in self.unet.attn_processors.values():
            if processor_name == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):attn_processor.mode = mode
            elif processor_name == 'self':
                if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.mode = mode
            elif processor_name == 'all':
                attn_processor.mode = mode

    def reset_all(self,keep_self=False):
        """Clear per-generation state stored on the attention processors.

        Args:
            keep_self: if True, keep ``store_ks`` / ``store_vs`` on the
                self-attention adapters (presumably the features captured by
                ``get_image_selfAttn_feature``).
        """
        self.subject_idxs = None
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.store_attn = None
                attn_processor.subject_idxs = None
                attn_processor.mask_i = None
                attn_processor.mask_ig_prev = None

            if isinstance(attn_processor, IPAttnProcessor_Self):
                attn_processor.mask_id, attn_processor.mask_ig = None, None
                if not keep_self:
                    attn_processor.store_ks, attn_processor.store_vs = [], []