hyoungwoncho committed
Commit 2315bef
1 Parent(s): 22808d8

Update pipeline.py

Files changed (1):
  1. pipeline.py +23 -48
pipeline.py CHANGED
@@ -38,10 +38,8 @@ EXAMPLE_DOC_STRING = """
         ```py
         >>> import torch
         >>> from diffusers import StableDiffusionPipeline
-
         >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
         >>> pipe = pipe.to("cuda")
-
         >>> prompt = "a photo of an astronaut riding a horse on mars"
         >>> image = pipe(prompt).images[0]
         ```
@@ -64,8 +62,12 @@ class PAGIdentitySelfAttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
         if attn.spatial_norm is not None:
@@ -91,11 +93,9 @@ class PAGIdentitySelfAttnProcessor:
         if attn.group_norm is not None:
             hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-
-        query = attn.to_q(hidden_states_org, *args)
-        key = attn.to_k(hidden_states_org, *args)
-        value = attn.to_v(hidden_states_org, *args)
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -115,7 +115,7 @@ class PAGIdentitySelfAttnProcessor:
         hidden_states_org = hidden_states_org.to(query.dtype)
 
         # linear proj
-        hidden_states_org = attn.to_out[0](hidden_states_org, *args)
+        hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)
 
@@ -134,9 +134,7 @@ class PAGIdentitySelfAttnProcessor:
         if attn.group_norm is not None:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-
-        value = attn.to_v(hidden_states_ptb, *args)
+        value = attn.to_v(hidden_states_ptb)
 
         hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
         #hidden_states_ptb = value
@@ -144,7 +142,7 @@ class PAGIdentitySelfAttnProcessor:
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
 
         # linear proj
-        hidden_states_ptb = attn.to_out[0](hidden_states_ptb, *args)
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
 
@@ -178,8 +176,12 @@ class PAGCFGIdentitySelfAttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
         if attn.spatial_norm is not None:
@@ -205,12 +207,10 @@ class PAGCFGIdentitySelfAttnProcessor:
 
         if attn.group_norm is not None:
             hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
-
-        args = () if USE_PEFT_BACKEND else (scale,)
 
-        query = attn.to_q(hidden_states_org, *args)
-        key = attn.to_k(hidden_states_org, *args)
-        value = attn.to_v(hidden_states_org, *args)
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -230,7 +230,7 @@ class PAGCFGIdentitySelfAttnProcessor:
         hidden_states_org = hidden_states_org.to(query.dtype)
 
         # linear proj
-        hidden_states_org = attn.to_out[0](hidden_states_org, *args)
+        hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)
 
@@ -249,14 +249,12 @@ class PAGCFGIdentitySelfAttnProcessor:
         if attn.group_norm is not None:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-
-        value = attn.to_v(hidden_states_ptb, *args)
+        value = attn.to_v(hidden_states_ptb)
         hidden_states_ptb = value
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
 
         # linear proj
-        hidden_states_ptb = attn.to_out[0](hidden_states_ptb, *args)
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
 
@@ -298,7 +296,6 @@ def retrieve_timesteps(
     """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
@@ -311,7 +308,6 @@ def retrieve_timesteps(
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.
-
    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
@@ -332,22 +328,19 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusionPAGPipeline(
+class StableDiffusionPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
-
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
     Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -540,7 +533,6 @@ class StableDiffusionPAGPipeline(
     ):
        r"""
        Encodes the prompt into text encoder hidden states.
-
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
@@ -885,12 +877,9 @@ class StableDiffusionPAGPipeline(
 
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
-
        The suffixes after the scaling factors represent the stages where they are being applied.
-
        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
-
        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
@@ -914,13 +903,9 @@ class StableDiffusionPAGPipeline(
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
-
        <Tip warning={true}>
-
        This API is 🧪 experimental.
-
        </Tip>
-
        Args:
            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
@@ -944,17 +929,12 @@ class StableDiffusionPAGPipeline(
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
     def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.
-
        <Tip warning={true}>
-
        This API is 🧪 experimental.
-
        </Tip>
-
        Args:
            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
-
        """
        if unet:
            if not self.fusing_unet:
@@ -974,7 +954,6 @@ class StableDiffusionPAGPipeline(
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
        Args:
            timesteps (`torch.Tensor`):
                generate embedding vectors at these timesteps
@@ -982,7 +961,6 @@ class StableDiffusionPAGPipeline(
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings
-
        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
        """
@@ -1128,7 +1106,6 @@ class StableDiffusionPAGPipeline(
     ):
        r"""
        The call function to the pipeline for generation.
-
        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
@@ -1195,9 +1172,7 @@ class StableDiffusionPAGPipeline(
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
-
        Examples:
-
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
 
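Note on the `scale` removal: both processors now accept `*args, **kwargs` only, so a stray `scale` keyword triggers `deprecate()` instead of a TypeError, and the deprecation message points callers at `cross_attention_kwargs`. A minimal sketch of the new calling convention, where the LoRA scale travels through the pipeline call rather than the processor signature (the LoRA repo id below is hypothetical; any loaded LoRA works the same way):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA checkpoint

# `scale` is no longer a parameter of the attention processor's __call__;
# it is forwarded to the attention layers via `cross_attention_kwargs`.
image = pipe(
    "a photo of an astronaut riding a horse on mars",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```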
38
  ```py
39
  >>> import torch
40
  >>> from diffusers import StableDiffusionPipeline
 
41
  >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
42
  >>> pipe = pipe.to("cuda")
 
43
  >>> prompt = "a photo of an astronaut riding a horse on mars"
44
  >>> image = pipe(prompt).images[0]
45
  ```
 
62
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
63
  attention_mask: Optional[torch.FloatTensor] = None,
64
  temb: Optional[torch.FloatTensor] = None,
65
+ *args,
66
+ **kwargs,
67
  ) -> torch.FloatTensor:
68
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
69
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
70
+ deprecate("scale", "1.0.0", deprecation_message)
71
 
72
  residual = hidden_states
73
  if attn.spatial_norm is not None:
 
93
  if attn.group_norm is not None:
94
  hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
95
 
96
+ query = attn.to_q(hidden_states_org)
97
+ key = attn.to_k(hidden_states_org)
98
+ value = attn.to_v(hidden_states_org)
 
 
99
 
100
  inner_dim = key.shape[-1]
101
  head_dim = inner_dim // attn.heads
 
115
  hidden_states_org = hidden_states_org.to(query.dtype)
116
 
117
  # linear proj
118
+ hidden_states_org = attn.to_out[0](hidden_states_org)
119
  # dropout
120
  hidden_states_org = attn.to_out[1](hidden_states_org)
121
 
 
134
  if attn.group_norm is not None:
135
  hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
136
 
137
+ value = attn.to_v(hidden_states_ptb)
 
 
138
 
139
  hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
140
  #hidden_states_ptb = value
 
142
  hidden_states_ptb = hidden_states_ptb.to(query.dtype)
143
 
144
  # linear proj
145
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
146
  # dropout
147
  hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
148
 
 
176
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
177
  attention_mask: Optional[torch.FloatTensor] = None,
178
  temb: Optional[torch.FloatTensor] = None,
179
+ *args,
180
+ **kwargs,
181
  ) -> torch.FloatTensor:
182
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
183
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
184
+ deprecate("scale", "1.0.0", deprecation_message)
185
 
186
  residual = hidden_states
187
  if attn.spatial_norm is not None:
 
207
 
208
  if attn.group_norm is not None:
209
  hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
 
 
210
 
211
+ query = attn.to_q(hidden_states_org)
212
+ key = attn.to_k(hidden_states_org)
213
+ value = attn.to_v(hidden_states_org)
214
 
215
  inner_dim = key.shape[-1]
216
  head_dim = inner_dim // attn.heads
 
230
  hidden_states_org = hidden_states_org.to(query.dtype)
231
 
232
  # linear proj
233
+ hidden_states_org = attn.to_out[0](hidden_states_org)
234
  # dropout
235
  hidden_states_org = attn.to_out[1](hidden_states_org)
236
 
 
249
  if attn.group_norm is not None:
250
  hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
251
 
252
+ value = attn.to_v(hidden_states_ptb)
 
 
253
  hidden_states_ptb = value
254
  hidden_states_ptb = hidden_states_ptb.to(query.dtype)
255
 
256
  # linear proj
257
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
258
  # dropout
259
  hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
260
 
 
296
  """
297
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
298
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
299
  Args:
300
  scheduler (`SchedulerMixin`):
301
  The scheduler to get timesteps from.
 
308
  Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
309
  timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
310
  must be `None`.
 
311
  Returns:
312
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
313
  second element is the number of inference steps.
 
328
  return timesteps, num_inference_steps
329
 
330
 
331
+ class StableDiffusionPipeline(
332
  DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
333
  ):
334
  r"""
335
  Pipeline for text-to-image generation using Stable Diffusion.
 
336
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
337
  implemented for all pipelines (downloading, saving, running on a particular device, etc.).
 
338
  The pipeline also inherits the following loading methods:
339
  - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
340
  - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
341
  - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
342
  - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
343
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
344
  Args:
345
  vae ([`AutoencoderKL`]):
346
  Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
 
533
  ):
534
  r"""
535
  Encodes the prompt into text encoder hidden states.
 
536
  Args:
537
  prompt (`str` or `List[str]`, *optional*):
538
  prompt to be encoded
 
877
 
878
  def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
879
  r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
 
880
  The suffixes after the scaling factors represent the stages where they are being applied.
 
881
  Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
882
  that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
 
883
  Args:
884
  s1 (`float`):
885
  Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
 
903
  """
904
  Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
905
  key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
 
906
  <Tip warning={true}>
 
907
  This API is 🧪 experimental.
 
908
  </Tip>
 
909
  Args:
910
  unet (`bool`, defaults to `True`): To apply fusion on the UNet.
911
  vae (`bool`, defaults to `True`): To apply fusion on the VAE.
 
929
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
930
  def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
931
  """Disable QKV projection fusion if enabled.
 
932
  <Tip warning={true}>
 
933
  This API is 🧪 experimental.
 
934
  </Tip>
 
935
  Args:
936
  unet (`bool`, defaults to `True`): To apply fusion on the UNet.
937
  vae (`bool`, defaults to `True`): To apply fusion on the VAE.
 
938
  """
939
  if unet:
940
  if not self.fusing_unet:
 
954
  def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
955
  """
956
  See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
 
957
  Args:
958
  timesteps (`torch.Tensor`):
959
  generate embedding vectors at these timesteps
 
961
  dimension of the embeddings to generate
962
  dtype:
963
  data type of the generated embeddings
 
964
  Returns:
965
  `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
966
  """
 
1106
  ):
1107
  r"""
1108
  The call function to the pipeline for generation.
 
1109
  Args:
1110
  prompt (`str` or `List[str]`, *optional*):
1111
  The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
 
1172
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1173
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1174
  `._callback_tensor_inputs` attribute of your pipeline class.
 
1175
  Examples:
 
1176
  Returns:
1177
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1178
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
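For readers of the perturbed branch above: Perturbed-Attention Guidance replaces the self-attention map with the identity matrix, so `softmax(QK^T / sqrt(d)) V` collapses to `V`. That is why `PAGCFGIdentitySelfAttnProcessor` assigns `hidden_states_ptb = value` directly, while this revision of `PAGIdentitySelfAttnProcessor` zeroes the branch outright and leaves the identity assignment commented out. A minimal, self-contained sketch of the equivalence, independent of the diffusers plumbing:

```py
import torch
import torch.nn.functional as F

batch, heads, tokens, head_dim = 1, 8, 64, 40
q = torch.randn(batch, heads, tokens, head_dim)
k = torch.randn(batch, heads, tokens, head_dim)
v = torch.randn(batch, heads, tokens, head_dim)

# Ordinary self-attention: softmax(QK^T / sqrt(d)) V
attn_out = F.scaled_dot_product_attention(q, k, v)

# PAG's perturbation: the attention map becomes the identity matrix, so the
# weighted sum over tokens returns each token's own value vector unchanged.
identity = torch.eye(tokens).expand(batch, heads, tokens, tokens)
ptb_out = identity @ v

assert ptb_out.shape == attn_out.shape
assert torch.allclose(ptb_out, v)  # identity attention is just the value projection
```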