skytnt committed
Commit 5c94345
1 Parent(s): d013b15

Update pipeline.py

Files changed (1)
  1. pipeline.py +60 -25
pipeline.py CHANGED
@@ -1,18 +1,20 @@
 import inspect
 import re
-import PIL
+from typing import Callable, List, Optional, Union
+
 import numpy as np
 import torch
-from typing import Callable, List, Optional, Union
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+import PIL
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -130,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -138,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
 
 
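A note on the new `truncated` flag: weights are copied once per subword token, so a single weighted phrase can occupy many token slots, and before this change any tokens past `max_length` were dropped silently. A minimal sketch of the behavior, using a hypothetical `fake_tokenize` in place of `pipe.tokenizer(word).input_ids[1:-1]`:

import logging

logger = logging.getLogger(__name__)

def fake_tokenize(word):
    # hypothetical stand-in for pipe.tokenizer(word).input_ids[1:-1]:
    # pretend every 4 characters form one subword token
    return [hash(word[i:i + 4]) % 1000 for i in range(0, len(word), 4)]

def prompt_to_tokens_and_weights(texts_and_weights, max_length):
    text_token, text_weight = [], []
    truncated = False
    for word, weight in texts_and_weights:
        token = fake_tokenize(word)
        text_token += token
        text_weight += [weight] * len(token)  # one weight per token, not per word
        if len(text_token) > max_length:
            truncated = True
            break
    if len(text_token) > max_length:
        truncated = True
        text_token = text_token[:max_length]
        text_weight = text_weight[:max_length]
    if truncated:
        logger.warning("Prompt was truncated.")  # surfaces what used to be silent
    return text_token, text_weight

# pairs as parse_prompt_attention would produce for "a photo of a (castle:1.1)"
tokens, weights = prompt_to_tokens_and_weights([("a photo of a ", 1.0), ("castle", 1.1)], 75)
assert len(tokens) == len(weights)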
 
@@ -171,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
             if len(weights[i]) == 0:
                 w = [1.0] * weights_length
             else:
-                for j in range((len(weights[i]) - 1) // chunk_length + 1):
+                for j in range(max_embeddings_multiples):
                     w.append(1.0)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                     w.append(1.0)  # weight for ending token in this chunk
                 w += [1.0] * (weights_length - len(w))
             weights[i] = w[:]
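The slicing fix above is the substantive change in this hunk: every chunk fed to the text encoder carries its own BOS and EOS, so a 77-slot chunk holds only 75 content tokens. Slicing weights by a full `chunk_length` put 77 weights into the first chunk and shifted every later weight away from its token (and the old loop bound could also undercount chunks). A quick check of the corrected arithmetic, assuming CLIP's `chunk_length = 77`:

chunk_length = 77
weights = [1.1] * 150  # two full chunks' worth of weighted content tokens
max_embeddings_multiples = 2

w = []
for j in range(max_embeddings_multiples):
    w.append(1.0)  # BOS slot of chunk j
    w += weights[j * (chunk_length - 2) : min(len(weights), (j + 1) * (chunk_length - 2))]
    w.append(1.0)  # EOS slot of chunk j

# chunk j now carries exactly content tokens 75*j .. 75*j+74, aligned with
# how the token ids themselves are chunked before encoding
assert len(w) == max_embeddings_multiples * chunk_length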
@@ -182,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 
 
 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -283,7 +289,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
 
     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
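The reflowed `min(...)` leaves the arithmetic unchanged, and it is worth spelling out: with `model_max_length = 77`, each chunk contributes `77 - 2 = 75` content tokens, so `(max_length - 1) // 75 + 1` is ceiling division, i.e. the number of chunks the longest prompt needs, capped by the caller's `max_embeddings_multiples`. The last line then re-adds a single leading BOS and trailing EOS slot. For example:

model_max_length = 77  # CLIP window: BOS + 75 content tokens + EOS

def chunks_needed(n_content_tokens):
    # ceiling division by the per-chunk content capacity
    return (n_content_tokens - 1) // (model_max_length - 2) + 1

assert chunks_needed(75) == 1   # fits in one window
assert chunks_needed(76) == 2   # one token over forces a second chunk
assert chunks_needed(150) == 2

for n in (75, 76, 150):
    m = chunks_needed(n)
    print(n, "content tokens ->", (model_max_length - 2) * m + 2, "padded slots")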
@@ -315,12 +322,18 @@
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
 
@@ -630,16 +643,29 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )
 
         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
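Context for the `mps` branch: when this was written, `torch.randn` could not take a `generator` on Apple's `mps` backend, so the pipeline samples on CPU and transfers, trading a copy for seed reproducibility. A self-contained sketch of the pattern (the helper name is hypothetical):

import torch

def sample_latents(shape, device, seed, dtype=torch.float32):
    # hypothetical helper mirroring the pipeline's branching: mps cannot
    # sample with a generator, so draw on CPU and move the result over
    if device.type == "mps":
        generator = torch.Generator(device="cpu").manual_seed(seed)
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)

# latents shape matches the hunk above: (batch, unet channels, h/8, w/8)
latents = sample_latents((1, 4, 64, 64), torch.device("cpu"), seed=0)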
@@ -682,11 +708,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
             t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -739,7 +773,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None
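For anyone exercising these changes end to end: this file is the long-prompt-weighting pipeline, normally loaded through diffusers' `custom_pipeline` mechanism. A usage sketch under that assumption (the model id, the `lpw_stable_diffusion` alias, and the parameter values are illustrative, not part of this commit):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",          # illustrative SD v1 checkpoint
    custom_pipeline="lpw_stable_diffusion",    # community alias for this pipeline
    torch_dtype=torch.float16,
).to("cuda")

# (word:1.2) upweights, [word] downweights; prompts beyond 77 tokens are split
# into up to max_embeddings_multiples chunks instead of being silently cut
image = pipe(
    prompt="a (masterpiece:1.2) photo of a castle, " + "highly detailed, " * 20,
    negative_prompt="lowres, blurry",
    max_embeddings_multiples=3,
    num_inference_steps=30,
).images[0]
image.save("castle.png")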
 