AlekseyCalvin committed
Commit 8d33af5
1 Parent(s): 1b5843f

Update custom_pipeline.py

Files changed (1)
  1. custom_pipeline.py +49 -2
custom_pipeline.py CHANGED
@@ -1,8 +1,20 @@
 import torch
 import numpy as np
 from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
-from typing import Any, Dict, List, Optional, Union
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from typing import Any, Callable, Dict, List, Optional, Union
 from PIL import Image
+from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+
+from diffusers.utils import is_torch_xla_available
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 # Constants for shift calculation
 BASE_SEQ_LEN = 256
@@ -54,6 +66,8 @@ class FluxWithCFGPipeline(FluxPipeline):
         prompt_2: Optional[Union[str, List[str]]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_inference_steps: int = 4,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
@@ -62,6 +76,8 @@ class FluxWithCFGPipeline(FluxPipeline):
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -102,6 +118,21 @@ class FluxWithCFGPipeline(FluxPipeline):
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        (
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            negative_text_ids,
+        ) = self.encode_prompt(
+            prompt=negative_prompt,
+            prompt_2=negative_prompt_2,
+            prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
@@ -114,6 +145,7 @@ class FluxWithCFGPipeline(FluxPipeline):
             generator,
             latents,
         )
+
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = latents.shape[1]
@@ -149,9 +181,24 @@ class FluxWithCFGPipeline(FluxPipeline):
                 joint_attention_kwargs=self.joint_attention_kwargs,
                 return_dict=False,
             )[0]
+
+            noise_pred_uncond = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=negative_pooled_prompt_embeds,
+                encoder_hidden_states=negative_prompt_embeds,
+                txt_ids=negative_text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            # Yield intermediate result
+            latents_dtype = latents.dtype
             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            # Yield intermediate result
             torch.cuda.empty_cache()
 
         # Final image
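
Note on the added denoising step: the second transformer call evaluates the model on the negative (unconditional) embeddings, and the two predictions are merged with the standard classifier-free guidance formula. A minimal, self-contained sketch of that merge (the helper name and tensor shapes below are illustrative, not part of the file):

import torch

def cfg_mix(noise_pred_text: torch.Tensor,
            noise_pred_uncond: torch.Tensor,
            guidance_scale: float) -> torch.Tensor:
    # Push the unconditional prediction toward the text-conditioned one,
    # scaled by guidance_scale; this is the same expression added inside the loop.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# With guidance_scale == 1.0 the mix collapses to the text-conditioned prediction.
text = torch.randn(1, 4096, 64)
uncond = torch.randn(1, 4096, 64)
assert torch.allclose(cfg_mix(text, uncond, 1.0), text, atol=1e-6)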
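
For reference, a minimal sketch of how the updated pipeline might be invoked with the new negative-prompt arguments; the checkpoint id, dtype, device, and output handling are assumptions rather than anything specified by this commit:

import torch
from custom_pipeline import FluxWithCFGPipeline

# Hypothetical checkpoint and dtype; substitute whatever the repo actually targets.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

result = pipe(
    prompt="a lighthouse at dawn, film photo",
    negative_prompt="blurry, low quality",  # argument added in this commit
    num_inference_steps=4,
    guidance_scale=3.5,
)
result.images[0].save("out.png")  # assumes a FluxPipelineOutput-style return value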