AlekseyCalvin committed on
Commit 57ecb18
Parent: 0369c77

Update pipeline.py

Files changed (1): pipeline.py (+56 -14)
pipeline.py CHANGED
@@ -100,6 +100,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
+       negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
@@ -156,6 +157,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile

        text_inputs = self.tokenizer(
            prompt,
+           negative_prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
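
Note: `negative_prompt` here is passed as the tokenizer's second positional argument, which Hugging Face tokenizers interpret as `text_pair`; the two texts are encoded as one paired sequence rather than yielding a separate unconditional encoding. A minimal sketch of the difference, assuming a standard `transformers` tokenizer (`t5-small` is a stand-in checkpoint, not the pipeline's actual one):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # stand-in for the pipeline's T5 tokenizer
    prompt = "a photo of a cat"
    negative_prompt = "blurry, low quality"

    # Paired call, as in this hunk: both texts end up in a single sequence.
    paired = tokenizer(prompt, negative_prompt, padding="max_length",
                       max_length=64, truncation=True, return_tensors="pt")

    # Separate calls: independent conditional and unconditional encodings,
    # which is what classifier-free guidance normally consumes.
    cond = tokenizer(prompt, padding="max_length", max_length=64,
                     truncation=True, return_tensors="pt")
    uncond = tokenizer(negative_prompt, padding="max_length", max_length=64,
                       truncation=True, return_tensors="pt")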
@@ -188,14 +190,18 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
+       do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
+       negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+       negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
+       adapter_weights: Optional[float] = None,
    ):
        r"""

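The widened `encode_prompt` signature follows the CFG-enabled pipelines elsewhere in diffusers, where `do_classifier_free_guidance` is derived from the guidance scale at call time. A hedged sketch of such a call; `pipe` stands for a loaded `FluxWithCFGPipeline`, and the argument values are illustrative rather than taken from this commit:

    # CFG is conventionally considered active whenever guidance_scale > 1.0.
    guidance_scale = 3.5
    do_classifier_free_guidance = guidance_scale > 1.0

    encoded = pipe.encode_prompt(
        prompt="a watercolor fox",
        prompt_2="a watercolor fox",
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt="photorealistic, harsh lighting",
        num_images_per_prompt=1,
        max_sequence_length=512,
    )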
@@ -232,7 +238,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
-       negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
@@ -251,16 +256,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                device=device,
            )

-       if self.text_encoder is not None:
-           if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-               # Retrieve the original scale by scaling back the LoRA layers
-               unscale_lora_layers(self.text_encoder, lora_scale)
-
-       if self.text_encoder_2 is not None:
-           if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-               # Retrieve the original scale by scaling back the LoRA layers
-               unscale_lora_layers(self.text_encoder_2, lora_scale)
-
        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

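Note: the deleted block was the counterpart to the `scale_lora_layers` calls made before encoding, so with it gone a passed `lora_scale` remains applied to the text encoders after `encode_prompt` returns. The usual diffusers pairing looks like the sketch below; `encode_with_scaled_lora` is a hypothetical helper, not code from this commit:

    from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

    def encode_with_scaled_lora(text_encoder, lora_scale, encode_fn):
        # Temporarily scale the LoRA layers, run the encoder, then restore
        # the original scale so later calls are unaffected.
        if USE_PEFT_BACKEND and lora_scale is not None:
            scale_lora_layers(text_encoder, lora_scale)
        try:
            return encode_fn(text_encoder)
        finally:
            if USE_PEFT_BACKEND and lora_scale is not None:
                unscale_lora_layers(text_encoder, lora_scale)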
@@ -270,9 +265,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        self,
        prompt,
        prompt_2,
-       negative_prompt,
+       negative_prompt=None,
        height,
        width,
+       lora_scale=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
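
As committed, this signature places the defaulted `negative_prompt=None` ahead of the positional `height` and `width`, which Python rejects (a parameter without a default may not follow one with a default). A valid ordering keeps the same parameter set and moves the defaulted ones to the back:

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=None,
        lora_scale=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        ...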
@@ -311,10 +307,56 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
+       if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+           raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+       prompt_attention_mask = text_inputs.attention_mask
+       prompt_attention_mask = prompt_attention_mask.to(device)
+
+       prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+       prompt_embeds = prompt_embeds[0]
+
+       if do_classifier_free_guidance and negative_prompt_embeds is None:
+           uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+           uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+           max_length = prompt_embeds.shape[1]
+           uncond_input = self.tokenizer(
+               uncond_tokens,
+               padding="max_length",
+               max_length=max_length,
+               truncation=True,
+               return_attention_mask=True,
+               add_special_tokens=True,
+               return_tensors="pt",
+           )
+           negative_prompt_attention_mask = uncond_input.attention_mask
+           negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+           negative_prompt_embeds = self.text_encoder(
+               uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+           )
+           negative_prompt_embeds = negative_prompt_embeds[0]
+
+       if do_classifier_free_guidance:
+           # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+           seq_len = negative_prompt_embeds.shape[1]
+
+           negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+           negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+           negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+           negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+           negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+       else:
+           negative_prompt_embeds = None
+           negative_prompt_attention_mask = None

+       return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
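
The long added block references names that `check_inputs` never defines (`text_inputs`, `text_input_ids`, `batch_size`, `bs_embed`, `dtype`, `clean_caption`, and a `self._text_preprocessing` helper), so it would raise `NameError` as committed; its shape resembles the `encode_prompt` of PixArt-style pipelines more than a validation method. A self-contained sketch of the negative-prompt encoding it appears to intend, with every input defined up front (the names and the `t5-small` checkpoint are illustrative assumptions):

    import torch
    from transformers import AutoTokenizer, T5EncoderModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    text_encoder = T5EncoderModel.from_pretrained("t5-small").to(device)

    negative_prompt = "blurry, oversaturated"
    batch_size = 1
    num_images_per_prompt = 2
    max_length = 64  # in the pipeline this would be prompt_embeds.shape[1]

    uncond_tokens = [negative_prompt] * batch_size
    uncond_input = tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    negative_prompt_attention_mask = uncond_input.attention_mask.to(device)

    # T5EncoderModel returns the last hidden state at index 0.
    negative_prompt_embeds = text_encoder(
        uncond_input.input_ids.to(device),
        attention_mask=negative_prompt_attention_mask,
    )[0]

    # Duplicate per generated image, mirroring the repeat/view in the diff.
    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(
        batch_size * num_images_per_prompt, seq_len, -1
    )
    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
        num_images_per_prompt, 1
    )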
@@ -437,7 +479,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
-       negative_prompt: Union[str, List[str]] = None,
+       negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 4,
        timesteps: List[int] = None,
@@ -457,7 +499,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
-
+
        # 1. Check inputs
        self.check_inputs(
            prompt,
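
Taken together, the commit threads negative prompts through the pipeline so classifier-free guidance can be applied at each denoising step. The guidance combination itself is the standard formulation; a minimal sketch with illustrative tensor shapes:

    import torch

    guidance_scale = 3.5

    # Stand-ins for the text-conditioned and negative-prompt model outputs
    # at a single denoising step.
    noise_pred_text = torch.randn(1, 16, 64, 64)
    noise_pred_uncond = torch.randn(1, 16, 64, 64)

    # Classifier-free guidance: move the prediction away from the
    # unconditional output, toward the text-conditioned one.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)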