Spaces:
Running
on
Zero
Running
on
Zero
AlekseyCalvin
commited on
Commit
•
57ecb18
1
Parent(s):
0369c77
Update pipeline.py
Browse files- pipeline.py +56 -14
pipeline.py
CHANGED
@@ -100,6 +100,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
100 |
def _get_t5_prompt_embeds(
|
101 |
self,
|
102 |
prompt: Union[str, List[str]] = None,
|
|
|
103 |
num_images_per_prompt: int = 1,
|
104 |
max_sequence_length: int = 512,
|
105 |
device: Optional[torch.device] = None,
|
@@ -156,6 +157,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
156 |
|
157 |
text_inputs = self.tokenizer(
|
158 |
prompt,
|
|
|
159 |
padding="max_length",
|
160 |
max_length=self.tokenizer_max_length,
|
161 |
truncation=True,
|
@@ -188,14 +190,18 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
188 |
self,
|
189 |
prompt: Union[str, List[str]],
|
190 |
prompt_2: Union[str, List[str]],
|
|
|
191 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
192 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
193 |
device: Optional[torch.device] = None,
|
194 |
num_images_per_prompt: int = 1,
|
195 |
prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
|
196 |
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
|
197 |
max_sequence_length: int = 512,
|
198 |
lora_scale: Optional[float] = None,
|
|
|
199 |
):
|
200 |
r"""
|
201 |
|
@@ -232,7 +238,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
232 |
scale_lora_layers(self.text_encoder_2, lora_scale)
|
233 |
|
234 |
prompt = [prompt] if isinstance(prompt, str) else prompt
|
235 |
-
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
236 |
|
237 |
if prompt_embeds is None:
|
238 |
prompt_2 = prompt_2 or prompt
|
@@ -251,16 +256,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
251 |
device=device,
|
252 |
)
|
253 |
|
254 |
-
if self.text_encoder is not None:
|
255 |
-
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
256 |
-
# Retrieve the original scale by scaling back the LoRA layers
|
257 |
-
unscale_lora_layers(self.text_encoder, lora_scale)
|
258 |
-
|
259 |
-
if self.text_encoder_2 is not None:
|
260 |
-
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
261 |
-
# Retrieve the original scale by scaling back the LoRA layers
|
262 |
-
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
263 |
-
|
264 |
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
265 |
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
266 |
|
@@ -270,9 +265,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
270 |
self,
|
271 |
prompt,
|
272 |
prompt_2,
|
273 |
-
negative_prompt,
|
274 |
height,
|
275 |
width,
|
|
|
276 |
prompt_embeds=None,
|
277 |
pooled_prompt_embeds=None,
|
278 |
callback_on_step_end_tensor_inputs=None,
|
@@ -311,10 +307,56 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
311 |
raise ValueError(
|
312 |
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
313 |
)
|
|
|
|
|
314 |
|
315 |
if max_sequence_length is not None and max_sequence_length > 512:
|
316 |
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
|
|
|
|
|
318 |
@staticmethod
|
319 |
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
320 |
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
@@ -437,7 +479,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
437 |
prompt_2: Optional[Union[str, List[str]]] = None,
|
438 |
height: Optional[int] = None,
|
439 |
width: Optional[int] = None,
|
440 |
-
negative_prompt: Union[str, List[str]] = None,
|
441 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
442 |
num_inference_steps: int = 4,
|
443 |
timesteps: List[int] = None,
|
@@ -457,7 +499,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
457 |
):
|
458 |
height = height or self.default_sample_size * self.vae_scale_factor
|
459 |
width = width or self.default_sample_size * self.vae_scale_factor
|
460 |
-
|
461 |
# 1. Check inputs
|
462 |
self.check_inputs(
|
463 |
prompt,
|
|
|
100 |
def _get_t5_prompt_embeds(
|
101 |
self,
|
102 |
prompt: Union[str, List[str]] = None,
|
103 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
104 |
num_images_per_prompt: int = 1,
|
105 |
max_sequence_length: int = 512,
|
106 |
device: Optional[torch.device] = None,
|
|
|
157 |
|
158 |
text_inputs = self.tokenizer(
|
159 |
prompt,
|
160 |
+
negative_prompt,
|
161 |
padding="max_length",
|
162 |
max_length=self.tokenizer_max_length,
|
163 |
truncation=True,
|
|
|
190 |
self,
|
191 |
prompt: Union[str, List[str]],
|
192 |
prompt_2: Union[str, List[str]],
|
193 |
+
do_classifier_free_guidance: bool = True,
|
194 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
195 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
196 |
device: Optional[torch.device] = None,
|
197 |
num_images_per_prompt: int = 1,
|
198 |
prompt_embeds: Optional[torch.FloatTensor] = None,
|
199 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
200 |
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
201 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
202 |
max_sequence_length: int = 512,
|
203 |
lora_scale: Optional[float] = None,
|
204 |
+
adapter_weights: Optional[float] = None,
|
205 |
):
|
206 |
r"""
|
207 |
|
|
|
238 |
scale_lora_layers(self.text_encoder_2, lora_scale)
|
239 |
|
240 |
prompt = [prompt] if isinstance(prompt, str) else prompt
|
|
|
241 |
|
242 |
if prompt_embeds is None:
|
243 |
prompt_2 = prompt_2 or prompt
|
|
|
256 |
device=device,
|
257 |
)
|
258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
260 |
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
261 |
|
|
|
265 |
self,
|
266 |
prompt,
|
267 |
prompt_2,
|
268 |
+
negative_prompt=None,
|
269 |
height,
|
270 |
width,
|
271 |
+
lora_scale=None,
|
272 |
prompt_embeds=None,
|
273 |
pooled_prompt_embeds=None,
|
274 |
callback_on_step_end_tensor_inputs=None,
|
|
|
307 |
raise ValueError(
|
308 |
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
309 |
)
|
310 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
311 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
312 |
|
313 |
if max_sequence_length is not None and max_sequence_length > 512:
|
314 |
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
315 |
+
|
316 |
+
prompt_attention_mask = text_inputs.attention_mask
|
317 |
+
prompt_attention_mask = prompt_attention_mask.to(device)
|
318 |
+
|
319 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
320 |
+
prompt_embeds = prompt_embeds[0]
|
321 |
+
|
322 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
323 |
+
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
324 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
325 |
+
max_length = prompt_embeds.shape[1]
|
326 |
+
uncond_input = self.tokenizer(
|
327 |
+
uncond_tokens,
|
328 |
+
padding="max_length",
|
329 |
+
max_length=max_length,
|
330 |
+
truncation=True,
|
331 |
+
return_attention_mask=True,
|
332 |
+
add_special_tokens=True,
|
333 |
+
return_tensors="pt",
|
334 |
+
)
|
335 |
+
negative_prompt_attention_mask = uncond_input.attention_mask
|
336 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
337 |
+
|
338 |
+
negative_prompt_embeds = self.text_encoder(
|
339 |
+
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
340 |
+
)
|
341 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
342 |
+
|
343 |
+
if do_classifier_free_guidance:
|
344 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
345 |
+
seq_len = negative_prompt_embeds.shape[1]
|
346 |
+
|
347 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
348 |
+
|
349 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
350 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
351 |
+
|
352 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
353 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
354 |
+
else:
|
355 |
+
negative_prompt_embeds = None
|
356 |
+
negative_prompt_attention_mask = None
|
357 |
|
358 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
359 |
+
|
360 |
@staticmethod
|
361 |
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
362 |
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
|
|
479 |
prompt_2: Optional[Union[str, List[str]]] = None,
|
480 |
height: Optional[int] = None,
|
481 |
width: Optional[int] = None,
|
482 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
483 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
484 |
num_inference_steps: int = 4,
|
485 |
timesteps: List[int] = None,
|
|
|
499 |
):
|
500 |
height = height or self.default_sample_size * self.vae_scale_factor
|
501 |
width = width or self.default_sample_size * self.vae_scale_factor
|
502 |
+
|
503 |
# 1. Check inputs
|
504 |
self.check_inputs(
|
505 |
prompt,
|