import torch
from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
import tqdm


def get_subject_idx(model, prompt, src_subject, device):
    """Find the token positions of each subject word inside its prompt.

    Args:
        model: pipeline exposing a Hugging Face ``tokenizer``.
        prompt: list of prompt strings, paired element-wise with ``src_subject``.
        src_subject: list of subject strings (one per prompt).
        device: device the tokenized prompt is moved to.

    Returns:
        One list per (subject, prompt) pair. Each list holds the positions in the
        padded ``input_ids`` whose token id equals the first sub-token of any
        whitespace-separated word of that subject.
    """
    tokenized_prompt = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    input_ids = tokenized_prompt['input_ids']

    all_subject_idxs = []
    for subject, input_id in zip(src_subject, input_ids):
        # Only the first sub-token of every word is matched. split() without an
        # argument never yields '' (which would encode to [] and fail on [0]).
        subject_token_ids = set()
        for word in subject.split():
            word_ids = model.tokenizer.encode(word, add_special_tokens=False)
            if word_ids:
                subject_token_ids.add(word_ids[0])
        # Collect per subject instead of overwriting, so no subject is dropped.
        all_subject_idxs.append(
            [i for i, x in enumerate(input_id.tolist()) if x in subject_token_ids]
        )
    return all_subject_idxs


def add_function(model):
    """Attach a ``generate_with_adapters`` sampling function to ``model``.

    ``model`` is expected to be a ``StableDiffusionPipeline`` that also carries a
    ``moMA_generator`` attribute (``MoMA_generator.__init__`` sets it).
    """

    @torch.no_grad()
    def generate_with_adapters(
        model,
        prompt_embeds,
        num_inference_steps,
        generator,
        t_range=list(range(0, 950)),
        guidance_scale=7.5,
        height=512,
        width=512,
    ):
        """Run classifier-free-guided denoising with the MoMA attention adapters.

        This is a plain function stored on the pipeline instance, not a bound
        method, so the pipeline is passed explicitly as ``model``.

        Args:
            model: the pipeline.
            prompt_embeds: ``[uncond; cond]`` embeddings concatenated along the batch
                dimension. The image batch size is ``prompt_embeds.shape[0] // 2``.
            num_inference_steps: number of scheduler steps.
            generator: optional ``torch.Generator`` for the initial latents.
            t_range: timesteps at which all adapters are enabled. At any other
                timestep only the cross-attention adapters run. Not mutated.
            guidance_scale: classifier-free guidance weight.
            height: output height in pixels.
            width: output width in pixels.

        Returns:
            ``(image, mask)``. ``image`` is the raw VAE decode. ``mask`` is the last
            cross-attention mask repeated to 3 channels, or ``None`` if no
            denoising step ran.
        """
        # Normalise to plain ints once; a set gives O(1) lookups per step.
        active_steps = {int(s) for s in t_range}
        latents = model.prepare_latents(
            prompt_embeds.shape[0] // 2,
            4,
            height,
            width,
            prompt_embeds.dtype,
            prompt_embeds.device,
            generator,
        )
        model.scheduler.set_timesteps(num_inference_steps)

        mask_ig_prev = None
        for t in tqdm.tqdm(model.scheduler.timesteps):
            if int(t) in active_steps:
                model.moMA_generator.toggle_enable_flag('all')
            else:
                model.moMA_generator.toggle_enable_flag('cross')

            # Duplicate latents so uncond and cond branches run in one UNet call.
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = model.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                return_dict=False,
            )[0]

            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = model.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Keep only the second (conditional) half of the attention-derived mask
            # and feed it back to the processors for the next step.
            mask_ig_prev = get_mask_from_cross(model.unet.attn_processors)[latents.shape[0]:]
            model.moMA_generator.set_self_mask('self', 'ig', mask_ig_prev)
            model.moMA_generator.set_self_mask('cross', mask=mask_ig_prev.clone().detach())

        # Cast to the VAE's dtype: the VAE may be kept in a different precision
        # from the float16 latents.
        scaled_latents = (latents / model.vae.config.scaling_factor).to(model.vae.dtype)
        image = model.vae.decode(scaled_latents, return_dict=False)[0]
        mask = mask_ig_prev.repeat(1, 3, 1, 1) if mask_ig_prev is not None else None
        return image, mask

    model.generate_with_adapters = generate_with_adapters


class ImageProjModel(torch.nn.Module):
    """Projection Model.

    Maps an embedding of size ``clip_embeddings_dim`` to
    ``clip_extra_context_tokens`` layer-normalised tokens of size
    ``cross_attention_dim``.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Project ``image_embeds`` to ``(-1, clip_extra_context_tokens, cross_attention_dim)``."""
        clip_extra_context_tokens = self.proj(image_embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(clip_extra_context_tokens)


class MoMA_generator:
    """Stable Diffusion pipeline wrapped with MoMA attention processors.

    The processors are driven by the ``toggle_*`` / ``set_*`` helpers.
    ``generate_with_MoMA`` is the generation entry point.
    """

    def __init__(self, device, args):
        self.args = args
        self.device = device
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
        print('Loading VAE: stabilityai--sd-vae-ft-mse...')
        # NOTE(review): the VAE is loaded in the default dtype and passed in as an
        # explicit component, so torch_dtype=float16 presumably does not apply to
        # it. The encode/decode call sites cast their inputs to ``vae.dtype``.
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        print('Loading StableDiffusion: Realistic_Vision...')
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V4.0_noVAE",
            torch_dtype=torch.float16,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
        ).to(self.device)
        self.unet = self.pipe.unet
        add_function(self.pipe)
        # Back-reference used by generate_with_adapters to toggle processors.
        self.pipe.moMA_generator = self

        self.set_ip_adapter()
        self.image_proj_model = self.init_proj()

    def init_proj(self):
        """Build the float16 projection from 1024-d embeddings to 4 tokens of width 768."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=768,
            clip_embeddings_dim=1024,
            clip_extra_context_tokens=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        """Install MoMA attention processors on every attention layer of the UNet.

        Self-attention layers (names ending in ``attn1.processor``) get
        ``IPAttnProcessor_Self``. All other layers get ``IPAttnProcessor``.

        Raises:
            ValueError: if a processor name is not under mid/up/down blocks.
        """
        unet = self.unet
        block_out_channels = unet.config.block_out_channels
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # Parse the whole index field, not just one character.
                block_id = int(name.split(".")[1])
                hidden_size = list(reversed(block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name.split(".")[1])
                hidden_size = block_out_channels[block_id]
            else:
                # Previously this silently reused the last iteration's hidden_size.
                raise ValueError(f"Unexpected attention processor name: {name}")

            processor_cls = IPAttnProcessor_Self if cross_attention_dim is None else IPAttnProcessor
            attn_procs[name] = processor_cls(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=4,
            ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    @torch.inference_mode()
    def get_image_embeds_CFG(self, llava_emb):
        """Project ``llava_emb`` and an all-zeros copy (the unconditional branch)."""
        clip_image_embeds = llava_emb
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def get_image_crossAttn_feature(
        self,
        llava_emb,
        num_samples=1,
    ):
        """Return cond/uncond projected embeddings, each repeated ``num_samples`` times per item."""
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_CFG(llava_emb)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        return image_prompt_embeds, uncond_image_prompt_embeds

    # feature are from self-attention layers of Unet: feed reference image to Unet with t=0
    def get_image_selfAttn_feature(
        self,
        pil_image,
        prompt,
    ):
        """Run the reference image through the UNet once to populate the self-attn processors.

        Enables only the ``IPAttnProcessor_Self`` processors and puts them in
        'extract' mode. It then encodes the image, adds timestep-0 noise (a small
        amount, not zero) and calls the UNet. The features are kept inside the
        processors (presumably ``store_ks`` / ``store_vs``, which ``reset_all``
        clears).

        Args:
            pil_image: despite the name, an image *tensor* accepted by the VAE
                encoder (it is moved with ``.to``).
            prompt: text used as UNet conditioning for this pass.
        """
        self.toggle_enable_flag('self')
        self.toggle_extract_inject_flag('self', 'extract')
        tokenized_prompt = self.pipe.tokenizer(
            prompt, padding="max_length", truncation=True, return_tensors="pt",
        ).to(self.device)
        text_embeddings = self.pipe.text_encoder(input_ids=tokenized_prompt.input_ids)[0]

        # Tensor.to is not in-place, so keep its result. Match the VAE's dtype too.
        ref_image = pil_image.to(device=self.device, dtype=self.pipe.vae.dtype)
        with torch.no_grad():
            latents = self.pipe.vae.encode(ref_image).latent_dist.sample()
        latents = latents * self.pipe.vae.config.scaling_factor

        noise = torch.randn_like(latents)
        timesteps = torch.tensor([0], device=latents.device).long()  # fixed to 0
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timesteps)
        # The VAE may run in a different precision from the UNet.
        noisy_latents = noisy_latents.to(self.unet.dtype)

        _ = self.unet(noisy_latents, timestep=timesteps, encoder_hidden_states=text_embeddings)["sample"]
        # features are stored in attn_processors
        return None

    @torch.no_grad()
    def generate_with_MoMA(
        self,
        batch,
        llava_emb=None,
        seed=None,
        device='cuda',
    ):
        """Generate one image conditioned on a reference subject and a prompt.

        Args:
            batch: dict with 'image' (reference image tensor), 'mask' (subject
                mask tensor), 'label' (subject names) and 'text' (prompts). Only
                element 0 of 'label' and 'text' is used.
            llava_emb: MLLM embedding fed to ``image_proj_model``.
            seed: optional seed for the initial latents.
            device: device for the batch tensors.

        Returns:
            ``(images, mask)`` on CPU. ``images`` is clipped to [0, 1]. ``mask``
            is ``None`` if no denoising step produced one.
        """
        self.reset_all()
        img_ig = batch['image'].half().to(device)
        mask_id = batch['mask'].half().to(device)
        subject = batch['label'][0]
        prompt = [f"photo of a {subject}. " + batch['text'][0]]
        subject_idx = get_subject_idx(self.pipe, prompt, [subject], self.device)
        negative_prompt = None

        # get context-cross-attention feature (from MLLM decoder)
        cond_llava_embeds, uncond_llava_embeds = self.get_image_crossAttn_feature(llava_emb, num_samples=1)
        # get subject-cross-attention feature (from Unet); stored in attn_processors
        self.get_image_selfAttn_feature(img_ig, subject)

        with torch.inference_mode():
            prompt_embeds = self.pipe._encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Split into negative/positive halves, append the image tokens to each
            # along the sequence dim, then re-stack as [negative; positive].
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            prompt_embeds = torch.cat([prompt_embeds_, cond_llava_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_llava_embeds], dim=1)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

        self.set_self_mask('eraseAll')
        self.toggle_enable_flag('all')
        self.toggle_extract_inject_flag('all', 'masked_generation')
        self.set_self_mask('self', 'id', mask_id)
        self.set_cross_subject_idxs(subject_idx)

        images, mask = self.pipe.generate_with_adapters(
            self.pipe,
            prompt_embeds,
            50,
            generator,
        )

        images = torch.clip((images + 1) / 2.0, min=0.0, max=1.0)
        return images.cpu(), (mask.cpu() if mask is not None else None)

    def set_selfAttn_strength(self, strength):
        """Set ``scale`` to 1.0 on cross-attn processors and to ``strength`` on self-attn ones."""
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = 1.0
            if isinstance(attn_processor, IPAttnProcessor_Self):
                attn_processor.scale = strength

    def set_cross_subject_idxs(self, subject_idxs):
        """Assign ``subject_idxs`` to every ``IPAttnProcessor``."""
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.subject_idxs = subject_idxs

    def set_self_mask(self, mode, id_ig='', mask=None):
        """Set or clear masks on the processors (only has effect on the generation pass).

        Args:
            mode: 'eraseAll' clears every mask. 'self' sets a mask on
                ``IPAttnProcessor_Self`` processors. 'cross' sets
                ``mask_ig_prev`` on ``IPAttnProcessor`` processors.
            id_ig: with ``mode='self'``, 'id' sets ``mask_id`` and 'ig' sets ``mask_ig``.
            mask: the mask to install.
        """
        for attn_processor in self.unet.attn_processors.values():
            if mode == 'eraseAll':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    attn_processor.mask_id, attn_processor.mask_ig = None, None
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mask_i, attn_processor.mask_ig_prev = None, None
            elif mode == 'self':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    if id_ig == 'id':
                        attn_processor.mask_id = mask
                    elif id_ig == 'ig':
                        attn_processor.mask_ig = mask
            elif mode == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mask_ig_prev = mask

    def toggle_enable_flag(self, processor_enable_mode):
        """Enable processors by kind.

        Args:
            processor_enable_mode: 'cross' enables only ``IPAttnProcessor``.
                'self' enables only ``IPAttnProcessor_Self``. 'all' enables every
                processor and 'none' disables every processor.
        """
        for attn_processor in self.unet.attn_processors.values():
            if processor_enable_mode == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.enabled = True
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    attn_processor.enabled = False
            elif processor_enable_mode == 'self':
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.enabled = False
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    attn_processor.enabled = True
            elif processor_enable_mode == 'all':
                attn_processor.enabled = True
            elif processor_enable_mode == 'none':
                attn_processor.enabled = False

    def toggle_extract_inject_flag(self, processor_name, mode):
        """Set ``mode`` on 'cross', 'self' or 'all' processors.

        ``mode`` is a string such as 'extract', 'inject', 'both' (cross only) or
        'masked_generation'.
        """
        for attn_processor in self.unet.attn_processors.values():
            if processor_name == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mode = mode
            elif processor_name == 'self':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    attn_processor.mode = mode
            elif processor_name == 'all':
                attn_processor.mode = mode

    def reset_all(self, keep_self=False):
        """Clear stored attention state and masks on every processor.

        Args:
            keep_self: if True, keep the self-attention ``store_ks`` / ``store_vs`` lists.
        """
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.store_attn = None
                attn_processor.subject_idxs = None
                attn_processor.mask_i = None
                attn_processor.mask_ig_prev = None
                # NOTE(review): this sets an attribute on the generator, not the
                # processor; kept for compatibility — confirm whether intended.
                self.subject_idxs = None
            if isinstance(attn_processor, IPAttnProcessor_Self):
                attn_processor.mask_id, attn_processor.mask_ig = None, None
                if not keep_self:
                    # Fresh lists per processor so no two processors share storage.
                    attn_processor.store_ks, attn_processor.store_vs = [], []