KatharinaK committed on
Commit 2a869f2
Parent: d223295

Added attention masking and intersect masking; fix truncation of prompts
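In short, the new intersect masking gates SEGA's per-concept edit guidance with the product of two binary masks: a quantile threshold on the guidance magnitude and a thresholded cross-attention map for the edit prompt. A minimal, self-contained sketch of that combination on dummy tensors follows (shapes, threshold value, and variable names are illustrative only, not the pipeline's API):

import torch
import torch.nn.functional as F

# Dummy stand-ins: per-concept edit guidance in latent space and a 16x16 cross-attention map.
noise_guidance_edit = torch.randn(1, 4, 64, 64)
attn_map = torch.rand(16, 16)
edit_threshold = 0.9

# SEGA-style quantile mask over the summed guidance magnitude (mirrors the diff below).
magnitude = torch.abs(noise_guidance_edit).sum(dim=1, keepdim=True).repeat(1, 4, 1, 1)
quantile = torch.quantile(magnitude.flatten(start_dim=2), edit_threshold, dim=2)
sega_mask = (magnitude >= quantile[:, :, None, None]).float()

# Cross-attention mask: threshold the attention map and resize it to the latent resolution.
attn_mask = (attn_map >= torch.quantile(attn_map.flatten(), edit_threshold)).float()
attn_mask = F.interpolate(attn_mask[None, None], size=noise_guidance_edit.shape[-2:])[0, 0]

# Intersect masking: guidance is only kept where both masks agree.
masked_guidance = noise_guidance_edit * (sega_mask * attn_mask)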

modified_pipeline_semantic_stable_diffusion.py CHANGED
@@ -9,16 +9,180 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
  from diffusers.image_processor import VaeImageProcessor
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.models.attention_processor import AttnProcessor, Attention
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
  from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.utils import logging, randn_tensor
+ from diffusers.utils import logging
+ from diffusers.utils.torch_utils import randn_tensor
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
  # from . import SemanticStableDiffusionPipelineOutput
 
+ import torch.nn.functional as F
+ import math
 
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
+ class AttentionStore():
+     @staticmethod
+     def get_empty_store():
+         return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                 "down_self": [], "mid_self": [], "up_self": []}
+
+     def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP):
+         # attn.shape = batch_size * head_size, seq_len query, seq_len_key
+         bs = 2 + int(PnP) + editing_prompts
+         source_batch_size = int(attn.shape[0] // bs)
+         skip = 2 if PnP else 1 # skip PnP & unconditional
+         self.forward(
+             attn[skip*source_batch_size:],
+             is_cross,
+             place_in_unet)
+
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+         if attn.shape[1] <= 32 ** 2: # avoid memory overhead
+             self.step_store[key].append(attn)
+
+     def between_steps(self, store_step=True):
+         if store_step:
+             if self.average:
+                 if len(self.attention_store) == 0:
+                     self.attention_store = self.step_store
+                 else:
+                     for key in self.attention_store:
+                         for i in range(len(self.attention_store[key])):
+                             self.attention_store[key][i] += self.step_store[key][i]
+             else:
+                 if len(self.attention_store) == 0:
+                     self.attention_store = [self.step_store]
+                 else:
+                     self.attention_store.append(self.step_store)
+
+         self.cur_step += 1
+         self.step_store = self.get_empty_store()
+
+     def get_attention(self, step: int):
+         if self.average:
+             attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
+         else:
+             assert(step is not None)
+             attention = self.attention_store[step]
+         return attention
+
+     def aggregate_attention(self, attention_maps, prompts, res: int,
+                             from_where: List[str], is_cross: bool, select: int
+                             ):
+         out = []
+         num_pixels = res ** 2
+         for location in from_where:
+             for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+                 if item.shape[1] == num_pixels:
+                     cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
+                     out.append(cross_maps)
+         out = torch.cat(out, dim=0)
+         # average over heads
+         out = out.sum(0) / out.shape[0]
+         return out
+
+     def __init__(self, average: bool):
+         self.step_store = self.get_empty_store()
+         self.attention_store = []
+         self.cur_step = 0
+         self.average = average
+
+ class CrossAttnProcessor:
+
+     def __init__(self, attention_store, place_in_unet, PnP, editing_prompts):
+         self.attnstore = attention_store
+         self.place_in_unet = place_in_unet
+         self.editing_prompts = editing_prompts
+         self.PnP = PnP
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         assert(not attn.residual_connection)
+         assert(attn.spatial_norm is None)
+         assert(attn.group_norm is None)
+         assert(hidden_states.ndim != 4)
+         assert(encoder_hidden_states is not None) # is cross
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         self.attnstore(attention_probs,
+                        is_cross=True,
+                        place_in_unet=self.place_in_unet,
+                        editing_prompts=self.editing_prompts,
+                        PnP=self.PnP)
+
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+         return hidden_states
+
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing
+ class GaussianSmoothing():
+
+     def __init__(self, device):
+         kernel_size = [3, 3]
+         sigma = [0.5, 0.5]
+
+         # The gaussian kernel is the product of the gaussian function of each dimension.
+         kernel = 1
+         meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+             mean = (size - 1) / 2
+             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
+
+         # Make sure sum of values in gaussian kernel equals 1.
+         kernel = kernel / torch.sum(kernel)
+
+         # Reshape to depthwise convolutional weight
+         kernel = kernel.view(1, 1, *kernel.size())
+         kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1))
+
+         self.weight = kernel.to(device)
+
+     def __call__(self, input):
+         """
+         Arguments:
+         Apply gaussian filter to input.
+         input (torch.Tensor): Input to apply gaussian filter on.
+         Returns:
+         filtered (torch.Tensor): Filtered output.
+         """
+         return F.conv2d(input, weight=self.weight.to(input.dtype))
+
 
  class SemanticStableDiffusionPipeline(DiffusionPipeline):
      r"""
@@ -207,6 +371,29 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
          latents = latents * self.scheduler.init_noise_sigma
          return latents
 
+     def prepare_unet(self, attention_store, PnP: bool):
+         attn_procs = {}
+         for name in self.unet.attn_processors.keys():
+             if name.startswith("mid_block"):
+                 place_in_unet = "mid"
+             elif name.startswith("up_blocks"):
+                 place_in_unet = "up"
+             elif name.startswith("down_blocks"):
+                 place_in_unet = "down"
+             else:
+                 continue
+
+             if "attn2" in name:
+                 attn_procs[name] = CrossAttnProcessor(
+                     attention_store=attention_store,
+                     place_in_unet=place_in_unet,
+                     PnP=PnP,
+                     editing_prompts=self.enabled_editing_prompts)
+             else:
+                 attn_procs[name] = AttnProcessor()
+
+         self.unet.set_attn_processor(attn_procs)
+
      @torch.no_grad()
      def __call__(
          self,
@@ -235,7 +422,13 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
          edit_mom_beta: Optional[float] = 0.4,
          edit_weights: Optional[List[float]] = None,
          sem_guidance: Optional[List[torch.Tensor]] = None,
-
+         # masking
+         use_cross_attn_mask: bool = False,
+         use_intersect_mask: bool = True,
+         edit_tokens_for_attn_map: List[str] = None,
+         # Attention store (just for visualization purposes)
+         attn_store_steps: Optional[List[int]] = [],
+         store_averaged_over_steps: bool = True,
          # DDPM additions
          use_ddpm: bool = False,
          wts: Optional[List[torch.Tensor]] = None,
@@ -334,6 +527,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
              second element is a list of `bool`s denoting whether the corresponding generated image likely represents
              "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
          """
+         if use_intersect_mask:
+             use_cross_attn_mask = True
+
+         if use_cross_attn_mask:
+             self.smoothing = GaussianSmoothing(self.device)
+
          # 0. Default height and width to unet
          height = height or self.unet.config.sample_size * self.vae_scale_factor
          width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -348,12 +547,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
              enable_edit_guidance = True
              if isinstance(editing_prompt, str):
                  editing_prompt = [editing_prompt]
-             enabled_editing_prompts = len(editing_prompt)
+             self.enabled_editing_prompts = len(editing_prompt)
          elif editing_prompt_embeddings is not None:
              enable_edit_guidance = True
-             enabled_editing_prompts = editing_prompt_embeddings.shape[0]
+             self.enabled_editing_prompts = editing_prompt_embeddings.shape[0]
          else:
-             enabled_editing_prompts = 0
+             self.enabled_editing_prompts = 0
              enable_edit_guidance = False
 
          # get prompt text embeddings
@@ -361,17 +560,23 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
              prompt,
              padding="max_length",
              max_length=self.tokenizer.model_max_length,
+             truncation=True,
              return_tensors="pt",
          )
          text_input_ids = text_inputs.input_ids
+         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+             text_input_ids, untruncated_ids
+         ):
+             removed_text = self.tokenizer.batch_decode(
+                 untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+             )
              logger.warning(
                  "The following part of your input was truncated because CLIP can only handle sequences up to"
                  f" {self.tokenizer.model_max_length} tokens: {removed_text}"
              )
-             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
          text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
          # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -382,24 +587,37 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
          if enable_edit_guidance:
              # get safety text embeddings
              if editing_prompt_embeddings is None:
+                 if edit_tokens_for_attn_map is not None:
+                     edit_tokens = [[word.replace("</w>", "") for word in self.tokenizer.tokenize(item)] for item in editing_prompt]
+                     #print(f"edit_tokens: {edit_tokens}")
+
                  edit_concepts_input = self.tokenizer(
                      [x for item in editing_prompt for x in repeat(item, batch_size)],
                      padding="max_length",
                      max_length=self.tokenizer.model_max_length,
+                     truncation=True,
                      return_tensors="pt",
+                     return_length=True
                  )
 
+                 num_edit_tokens = edit_concepts_input.length -2 # not counting startoftext and endoftext
                  edit_concepts_input_ids = edit_concepts_input.input_ids
+                 untruncated_ids = self.tokenizer(
+                     [x for item in editing_prompt for x in repeat(item, batch_size)],
+                     padding="longest",
+                     return_tensors="pt").input_ids
 
-                 if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length:
+                 if untruncated_ids.shape[-1] >= edit_concepts_input_ids.shape[-1] and not torch.equal(
+                     edit_concepts_input_ids, untruncated_ids
+                 ):
                      removed_text = self.tokenizer.batch_decode(
-                         edit_concepts_input_ids[:, self.tokenizer.model_max_length :]
+                         untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                      )
                      logger.warning(
                          "The following part of your input was truncated because CLIP can only handle sequences up to"
                          f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                      )
-                 edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
+
                  edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
          else:
              edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
@@ -453,8 +671,11 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
              # For classifier free guidance, we need to do two forward passes.
              # Here we concatenate the unconditional and text embeddings into a single batch
              # to avoid doing two forward passes
+             self.text_cross_attention_maps = [prompt] if isinstance(prompt, str) else prompt
              if enable_edit_guidance:
                  text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
+                 self.text_cross_attention_maps += \
+                     ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
              else:
                  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
          # get the initial random noise unless the user supplied it
@@ -466,6 +687,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
              t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
              timesteps = timesteps[-zs.shape[0]:]
 
+         self.attention_store = AttentionStore(average=store_averaged_over_steps)
+         self.prepare_unet(self.attention_store, False)
+
          # 5. Prepare latent variables
          num_channels_latents = self.unet.config.in_channels
          latents = self.prepare_latents(
@@ -493,7 +717,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
          for i, t in enumerate(self.progress_bar(timesteps)):
              # expand the latents if we are doing classifier free guidance
              latent_model_input = (
-                 torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents
+                 torch.cat([latents] * (2 + self.enabled_editing_prompts)) if do_classifier_free_guidance else latents
              )
              latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
@@ -502,7 +726,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
 
              # perform guidance
              if do_classifier_free_guidance:
-                 noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64]
+                 noise_pred_out = noise_pred.chunk(2 + self.enabled_editing_prompts) # [b,4, 64, 64]
                  noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
                  noise_pred_edit_concepts = noise_pred_out[2:]
 
@@ -589,27 +813,115 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
 
                          noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
 
-                         # torch.quantile function expects float32
-                         if noise_guidance_edit_tmp.dtype == torch.float32:
-                             tmp = torch.quantile(
-                                 torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2),
-                                 edit_threshold_c,
-                                 dim=2,
-                                 keepdim=False,
+                         if use_cross_attn_mask:
+                             out = self.attention_store.aggregate_attention(
+                                 attention_maps=self.attention_store.step_store,
+                                 prompts=self.text_cross_attention_maps,
+                                 res=16,
+                                 from_where=["up","down"],
+                                 is_cross=True,
+                                 select=self.text_cross_attention_maps.index(editing_prompt[c]),
                              )
-                         else:
-                             tmp = torch.quantile(
-                                 torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32),
-                                 edit_threshold_c,
-                                 dim=2,
-                                 keepdim=False,
-                             ).to(noise_guidance_edit_tmp.dtype)
-
-                         noise_guidance_edit_tmp = torch.where(
-                             torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None],
-                             noise_guidance_edit_tmp,
-                             torch.zeros_like(noise_guidance_edit_tmp),
-                         )
+
+                             attn_map = out[:, :, 1:] # 0 -> startoftext
+                             attn_map *= 100
+                             attn_map = torch.nn.functional.softmax(attn_map, dim=-1)
+                             attn_map = attn_map[:,:,:num_edit_tokens[c]] # -1 -> endoftext
+
+                             assert(attn_map.shape[2]==num_edit_tokens[c])
+                             if edit_tokens_for_attn_map is not None:
+                                 # select attn_map for specified tokens
+                                 token_idx = [edit_tokens[c].index(item) for item in edit_tokens_for_attn_map[c]]
+                                 attn_map = attn_map[:,:,token_idx]
+                                 assert(attn_map.shape[2] == len(edit_tokens_for_attn_map[c]))
+
+                             # average over tokens
+                             attn_map = torch.sum(attn_map, dim=2)
+
+                             # gaussian_smoothing
+                             attn_map = F.pad(attn_map.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
+                             attn_map = self.smoothing(attn_map).squeeze(0).squeeze(0)
+
+                             # torch.quantile function expects float32
+                             if attn_map.dtype == torch.float32:
+                                 tmp = torch.quantile(
+                                     attn_map.flatten(),
+                                     edit_threshold_c
+                                 )
+                             else:
+                                 tmp = torch.quantile(
+                                     attn_map.flatten().to(torch.float32),
+                                     edit_threshold_c
+                                 ).to(attn_map.dtype)
+
+                             attn_mask = torch.where(attn_map >= tmp, 1.0, 0.0)
+
+                             # resolution must match latent space dimension
+                             attn_mask = F.interpolate(
+                                 attn_mask.unsqueeze(0).unsqueeze(0),
+                                 noise_guidance_edit_tmp.shape[-2:] # 64,64
+                             )[0,0,:,:]
+
+                             if not use_intersect_mask:
+                                 noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
+
+                         if use_intersect_mask:
+                             noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
+                             noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, keepdim=True)
+                             noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1,4,1,1)
+
+                             if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
+                                 tmp = torch.quantile(
+                                     noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
+                                     edit_threshold_c,
+                                     dim=2,
+                                     keepdim=False,
+                                 )
+                             else:
+                                 tmp = torch.quantile(
+                                     noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
+                                     edit_threshold_c,
+                                     dim=2,
+                                     keepdim=False,
+                                 ).to(noise_guidance_edit_tmp_quantile.dtype)
+
+                             sega_mask = torch.where(
+                                 noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
+                                 torch.ones_like(noise_guidance_edit_tmp),
+                                 torch.zeros_like(noise_guidance_edit_tmp),
+                             )
+
+                             intersect_mask = sega_mask * attn_mask
+                             noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
+
+                         elif not use_cross_attn_mask:
+                             # calculate quantile
+                             noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
+                             noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, keepdim=True)
+                             noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1,4,1,1)
+
+                             # torch.quantile function expects float32
+                             if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
+                                 tmp = torch.quantile(
+                                     noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
+                                     edit_threshold_c,
+                                     dim=2,
+                                     keepdim=False,
+                                 )
+                             else:
+                                 tmp = torch.quantile(
+                                     noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
+                                     edit_threshold_c,
+                                     dim=2,
+                                     keepdim=False,
+                                 ).to(noise_guidance_edit_tmp_quantile.dtype)
+
+                             noise_guidance_edit_tmp = torch.where(
+                                 noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
+                                 noise_guidance_edit_tmp,
+                                 torch.zeros_like(noise_guidance_edit_tmp),
+                             )
+
                          noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
 
                          # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
@@ -712,6 +1024,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
              else: #if not use_ddpm:
                  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
+             # step callback
+             store_step = i in attn_store_steps
+             if store_step:
+                 print("storing attention")
+             self.attention_store.between_steps(store_step)
+
              # call the callback, if provided
              if callback is not None and i % callback_steps == 0:
                  callback(i, t, latents)
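
A rough usage sketch for the arguments this commit adds, assuming the modified pipeline file is importable locally and the checkpoint loads with the usual `from_pretrained` call; argument names not shown in the diff (e.g. `edit_guidance_scale`, `edit_threshold`) are taken from the existing SEGA signature and should be treated as assumptions:

import torch
from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline

# Assumed loading path; any Stable Diffusion 1.x checkpoint with matching components should work.
pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

out = pipe(
    prompt="a photo of a cat",
    editing_prompt=["sunglasses"],
    edit_guidance_scale=[7.0],          # assumed existing SEGA arguments
    edit_threshold=[0.95],
    # new in this commit: cross-attention and intersect masking
    use_cross_attn_mask=True,
    use_intersect_mask=True,            # implies use_cross_attn_mask=True
    edit_tokens_for_attn_map=[["sunglasses"]],
    # new in this commit: store attention maps for visualization
    attn_store_steps=[10, 25],
    store_averaged_over_steps=True,
)
image = out.images[0]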