dn6 (HF Staff) committed
Commit 56f2217 · verified · 1 Parent(s): ed91f33

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +32 -0
  2. block.py +601 -0
  3. modular_config.json +7 -0
README.md ADDED
@@ -0,0 +1,32 @@
+ # Cached Flux Prompt Encoding
+
+ This is a custom block that caches encoded input prompts for the Flux model. Encoded prompts are stored in a safetensors file, keyed by a hash of the prompt string. Prompts already present in the cache are loaded directly from the file, while new prompts are encoded and added to the cache.
+
+ # How to use
+
+ ```python
+ import torch
+ from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
+ from diffusers.modular_pipelines.flux.modular_blocks import TEXT2IMAGE_BLOCKS
+
+ prompt_encoder_block = ModularPipelineBlocks.from_pretrained(
+     "diffusers/flux-cached-prompt-encoder-custom-block",
+     trust_remote_code=True
+ )
+
+ blocks = TEXT2IMAGE_BLOCKS.copy().insert("text_encoder", prompt_encoder_block, 0)
+ blocks = SequentialPipelineBlocks.from_blocks_dict(blocks)
+
+ repo_id = "diffusers/modular-FLUX.1-dev"
+ pipe = blocks.init_pipeline(repo_id)
+ pipe.load_components(torch_dtype=torch.bfloat16, device_map="cuda")
+
+ output = pipe(
+     prompt="a photo of an astronaut riding a horse on the moon",
+     num_inference_steps=35,
+     guidance_scale=3.5,
+     output_type="pil",
+ )
+ image = output.values['image']
+ ```
+
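For reference, the caching scheme the README describes (hash the prompt, use the hash as a key into a safetensors file, encode only on a miss) can be sketched in a few lines. This is a simplified, standalone illustration and not part of the commit: `encode_fn` stands in for the real CLIP/T5 encoding calls, and the cache path simply mirrors the block's default of `~/.cache/flux_prompt_cache`; the actual block in `block.py` below additionally suffixes keys with `_clip` or `_t5` per encoder.

```python
import hashlib
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

CACHE_FILE = os.path.expanduser("~/.cache/flux_prompt_cache/cache.safetensors")


def cache_key(prompt: str, suffix: str) -> str:
    # Key = SHA-256 hash of the prompt plus a per-encoder suffix ("_clip" or "_t5").
    return hashlib.sha256((prompt + suffix).encode()).hexdigest()


def lookup_or_encode(prompt: str, suffix: str, encode_fn) -> torch.Tensor:
    # Load any existing cache entries from disk.
    cache = {}
    if os.path.exists(CACHE_FILE):
        with safe_open(CACHE_FILE, framework="pt", device="cpu") as f:
            cache = {key: f.get_tensor(key) for key in f.keys()}

    key = cache_key(prompt, suffix)
    if key not in cache:
        # Cache miss: encode the prompt and persist the updated cache.
        cache[key] = encode_fn(prompt).cpu()
        os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
        save_file(cache, CACHE_FILE)
    return cache[key]
```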
block.py ADDED
@@ -0,0 +1,601 @@
+ import hashlib
+ import os
+ from typing import List, Optional, Union
+
+ import torch
+ from diffusers import FluxModularPipeline, ModularPipelineBlocks
+ from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+ from diffusers.modular_pipelines import PipelineState
+ from diffusers.modular_pipelines.modular_pipeline_utils import (
+     ComponentSpec,
+     InputParam,
+     OutputParam,
+ )
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     logger,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+
+ class CachedFluxTextEncoderStep(ModularPipelineBlocks):
+     model_name = "flux"
+
+     def __init__(
+         self,
+         use_cache: bool = True,
+         cache_dir: Optional[str] = None,
+         load_from_disk: bool = True,
+     ) -> None:
+         """Initialize the cached Flux text encoder step.
+
+         Args:
+             use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
+             cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
+             load_from_disk: Whether to load existing cache from disk on initialization. Defaults to True.
+         """
+         super().__init__()
+         self.cache = {} if use_cache else None
+         if use_cache:
+             self.cache_dir = cache_dir or os.path.join(
+                 os.path.expanduser("~"), ".cache", "flux_prompt_cache"
+             )
+             os.makedirs(self.cache_dir, exist_ok=True)
+         else:
+             self.cache_dir = None
+
+         # Load existing cache if requested
+         if load_from_disk and use_cache:
+             self.load_cache_from_disk()
+
+     @property
+     def description(self) -> str:
+         return "Text Encoder step that generates text embeddings to guide the image generation"
+
+     @property
+     def expected_components(self):
+         return [
+             ComponentSpec("text_encoder", CLIPTextModel),
+             ComponentSpec("tokenizer", CLIPTokenizer),
+             ComponentSpec("text_encoder_2", T5EncoderModel),
+             ComponentSpec("tokenizer_2", T5TokenizerFast),
+         ]
+
+     @property
+     def expected_configs(self):
+         return []
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam("prompt"),
+             InputParam("prompt_2"),
+             InputParam("joint_attention_kwargs"),
+         ]
+
+     @property
+     def intermediate_outputs(self):
+         return [
+             OutputParam(
+                 "prompt_embeds",
+                 type_hint=torch.Tensor,
+                 description="text embeddings used to guide the image generation",
+             ),
+             OutputParam(
+                 "pooled_prompt_embeds",
+                 type_hint=torch.Tensor,
+                 description="pooled text embeddings used to guide the image generation",
+             ),
+             OutputParam(
+                 "text_ids",
+                 type_hint=torch.Tensor,
+                 description="ids from the text sequence for RoPE",
+             ),
+         ]
+
+     @staticmethod
+     def check_inputs(block_state):
+         for prompt in [block_state.prompt, block_state.prompt_2]:
+             if prompt is not None and (
+                 not isinstance(prompt, str) and not isinstance(prompt, list)
+             ):
+                 raise ValueError(
+                     f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
+                 )
+
+     def save_cache_to_disk(self):
+         """Save the current cache to disk as a safetensors file."""
+         if not self.cache or not self.cache_dir:
+             return
+
+         cache_file = os.path.join(self.cache_dir, "cache.safetensors")
+
+         # Prepare tensors dict for safetensors
+         tensors_to_save = {}
+         for key, tensor in self.cache.items():
+             # Ensure tensor is on CPU before saving
+             cpu_tensor = (
+                 tensor.cpu() if tensor.device != torch.device("cpu") else tensor
+             )
+             tensors_to_save[key] = cpu_tensor
+
+         # Save tensors
+         save_file(tensors_to_save, cache_file)
+         logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")
+
+     def load_cache_from_disk(self):
+         """Load cache from disk using memory-mapped safetensors."""
+         if not self.cache_dir or self.cache is None:
+             return
+
+         cache_file = os.path.join(self.cache_dir, "cache.safetensors")
+
+         if not os.path.exists(cache_file):
+             return
+
+         try:
+             # Open safetensors file in context manager
+             with safe_open(cache_file, framework="pt", device="cpu") as f:
+                 loaded_count = 0
+                 for key in f.keys():
+                     self.cache[key] = f.get_tensor(key)
+                     loaded_count += 1
+
+             logger.debug(
+                 f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
+             )
+         except Exception as e:
+             logger.warning(f"Failed to load cache from disk: {e}")
+
+     def clear_cache_from_disk(self):
+         """Clear cached safetensors file from disk."""
+         if not self.cache_dir:
+             return
+
+         cache_file = os.path.join(self.cache_dir, "cache.safetensors")
+         if os.path.exists(cache_file):
+             os.remove(cache_file)
+             logger.info(f"Cleared cache file: {cache_file}")
+
+         # Also clear the in-memory cache
+         if self.cache:
+             self.cache.clear()
+
+     def get_cache_size(self):
+         """Get the current cache size in MB."""
+         if not self.cache_dir:
+             return 0
+
+         cache_file = os.path.join(self.cache_dir, "cache.safetensors")
+         if os.path.exists(cache_file):
+             return os.path.getsize(cache_file) / (1024 * 1024)  # Convert to MB
+         return 0
+
+     @staticmethod
+     def _to_cache_key(prompt: str) -> str:
+         """Generate a hash key for a single prompt string."""
+         return hashlib.sha256(prompt.encode()).hexdigest()
+
+     @staticmethod
+     def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
+         """Split prompts into cached and new, returning indices for reconstruction.
+
+         Args:
+             prompts: List of prompt strings to check against cache.
+             cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
+             cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
+             device: Optional device to move cached tensors to.
+
+         Returns:
+             tuple: (cached_embeds, prompts_to_encode, prompt_indices)
+                 - cached_embeds: List of (idx, embedding) tuples for cached prompts
+                 - prompts_to_encode: List of prompts that need encoding
+                 - prompt_indices: List of original indices for prompts_to_encode
+         """
+         cached_embeds = []
+         prompts_to_encode = []
+         prompt_indices = []
+
+         for idx, prompt in enumerate(prompts):
+             cache_key = CachedFluxTextEncoderStep._to_cache_key(prompt + cache_suffix)
+             if (
+                 cache_instance
+                 and cache_instance.cache
+                 and cache_key in cache_instance.cache
+             ):
+                 cached_tensor = cache_instance.cache[cache_key]
+                 # Move tensor to the correct device if specified
+                 if device is not None and cached_tensor.device != device:
+                     cached_tensor = cached_tensor.to(device)
+                 cached_embeds.append((idx, cached_tensor))
+             else:
+                 prompts_to_encode.append(prompt)
+                 prompt_indices.append(idx)
+
+         return cached_embeds, prompts_to_encode, prompt_indices
+
+     @staticmethod
+     def _cache_prompt_embeds(
+         prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
+     ):
+         """Store newly computed embeddings in cache and save to disk.
+
+         Args:
+             prompts: Original full list of prompts.
+             prompt_indices: Indices of newly encoded prompts in the original list.
+             prompt_embeds: Newly computed embeddings tensor.
+             cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
+             cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
+         """
+         if not cache_instance or cache_instance.cache is None:
+             return
+
+         for i, idx in enumerate(prompt_indices):
+             cache_key = CachedFluxTextEncoderStep._to_cache_key(
+                 prompts[idx] + cache_suffix
+             )
+             # Store the slice in the in-memory cache (tensors are moved to CPU when the cache is written to disk)
+             tensor_slice = prompt_embeds[i : i + 1]
+             cache_instance.cache[cache_key] = tensor_slice
+
+         # Save updated cache to disk
+         cache_instance.save_cache_to_disk()
+
+     @staticmethod
+     def _merge_cached_prompt_embeds(
+         cached_embeds, prompt_indices, prompt_embeds, batch_size
+     ):
+         """Merge cached and newly computed embeddings back into original batch order.
+
+         Args:
+             cached_embeds: List of (idx, embedding) tuples from cache.
+             prompt_indices: Indices where new embeddings should be placed.
+             prompt_embeds: Newly computed embeddings tensor, or None if all cached.
+             batch_size: Total batch size for output tensor.
+
+         Returns:
+             torch.Tensor: Combined embeddings tensor in correct batch order.
+         """
+         all_embeds = [None] * batch_size
+
+         # Place cached embeddings
+         for idx, embed in cached_embeds:
+             all_embeds[idx] = embed
+
+         # Place new embeddings
+         if prompt_embeds is not None:
+             for i, idx in enumerate(prompt_indices):
+                 all_embeds[idx] = prompt_embeds[i : i + 1]
+
+         return torch.cat(all_embeds, dim=0)
+
+     @staticmethod
+     def _get_t5_prompt_embeds(
+         components,
+         prompt: Union[str, List[str]] = None,
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 512,
+         device: torch.device = None,
+         cache_instance=None,
+     ):
+         """Encode prompts using T5 text encoder with caching support.
+
+         Args:
+             components: Pipeline components containing T5 encoder and tokenizer.
+             prompt: Prompt(s) to encode.
+             num_images_per_prompt: Number of images per prompt for duplication.
+             max_sequence_length: Maximum sequence length for tokenization.
+             device: Device to place tensors on.
+             cache_instance: CachedFluxTextEncoderStep instance for caching, or None.
+
+         Returns:
+             torch.Tensor: T5 prompt embeddings ready for diffusion model.
+         """
+         dtype = components.text_encoder_2.dtype
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         cached_embeds, prompts_to_encode, prompt_indices = (
+             CachedFluxTextEncoderStep._get_cached_prompt_embeds(
+                 prompt, cache_instance, "_t5", device
+             )
+         )
+
+         if not prompts_to_encode:
+             prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
+                 cached_embeds, prompt_indices, None, batch_size
+             )
+             _, seq_len, _ = prompt_embeds.shape
+
+             prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             prompt_embeds = prompt_embeds.view(
+                 batch_size * num_images_per_prompt, seq_len, -1
+             )
+             return prompt_embeds
+
+         if isinstance(components, TextualInversionLoaderMixin):
+             prompts_to_encode = components.maybe_convert_prompt(
+                 prompts_to_encode, components.tokenizer_2
+             )
+
+         text_inputs = components.tokenizer_2(
+             prompts_to_encode,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             return_length=False,
+             return_overflowing_tokens=False,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+
+         # Check for truncation
+         untruncated_ids = components.tokenizer_2(
+             prompts_to_encode, padding="longest", return_tensors="pt"
+         ).input_ids
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+             text_input_ids, untruncated_ids
+         ):
+             removed_text = components.tokenizer_2.batch_decode(
+                 untruncated_ids[:, max_sequence_length - 1 : -1]
+             )
+             logger.warning(
+                 "The following part of your input was truncated because `max_sequence_length` is set to "
+                 f" {max_sequence_length} tokens: {removed_text}"
+             )
+
+         prompt_embeds = components.text_encoder_2(
+             text_input_ids.to(device), output_hidden_states=False
+         )[0]
+
+         CachedFluxTextEncoderStep._cache_prompt_embeds(
+             prompt, prompt_indices, prompt_embeds, cache_instance, "_t5"
+         )
+
+         prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
+             cached_embeds, prompt_indices, prompt_embeds, batch_size
+         )
+         _, seq_len, _ = prompt_embeds.shape
+
+         # Duplicate for num_images_per_prompt
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(
+             batch_size * num_images_per_prompt, seq_len, -1
+         )
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+         return prompt_embeds
+
+     @staticmethod
+     def _get_clip_prompt_embeds(
+         components,
+         prompt: Union[str, List[str]] = None,
+         num_images_per_prompt: int = 1,
+         device: torch.device = None,
+         cache_instance=None,
+     ):
+         """Encode prompts using CLIP text encoder with caching support.
+
+         Args:
+             components: Pipeline components containing CLIP encoder and tokenizer.
+             prompt: Prompt(s) to encode.
+             num_images_per_prompt: Number of images per prompt for duplication.
+             device: Device to place tensors on.
+             cache_instance: CachedFluxTextEncoderStep instance for caching, or None.
+
+         Returns:
+             torch.Tensor: CLIP pooled prompt embeddings ready for diffusion model.
+         """
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         # Split cached and new prompts
+         cached_embeds, prompts_to_encode, prompt_indices = (
+             CachedFluxTextEncoderStep._get_cached_prompt_embeds(
+                 prompt, cache_instance, "_clip", device
+             )
+         )
+
+         # Early return if all prompts are cached
+         if not prompts_to_encode:
+             prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
+                 cached_embeds, prompt_indices, None, batch_size
+             )
+             prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+             prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+             return prompt_embeds
+
+         if prompts_to_encode:
+             if isinstance(components, TextualInversionLoaderMixin):
+                 prompts_to_encode = components.maybe_convert_prompt(
+                     prompts_to_encode, components.tokenizer
+                 )
+
+             text_inputs = components.tokenizer(
+                 prompts_to_encode,
+                 padding="max_length",
+                 max_length=components.tokenizer.model_max_length,
+                 truncation=True,
+                 return_overflowing_tokens=False,
+                 return_length=False,
+                 return_tensors="pt",
+             )
+
+             text_input_ids = text_inputs.input_ids
+             tokenizer_max_length = components.tokenizer.model_max_length
+             untruncated_ids = components.tokenizer(
+                 prompts_to_encode, padding="longest", return_tensors="pt"
+             ).input_ids
+
+             if untruncated_ids.shape[-1] >= text_input_ids.shape[
+                 -1
+             ] and not torch.equal(text_input_ids, untruncated_ids):
+                 removed_text = components.tokenizer.batch_decode(
+                     untruncated_ids[:, tokenizer_max_length - 1 : -1]
+                 )
+                 logger.warning(
+                     "The following part of your input was truncated because CLIP can only handle sequences up to"
+                     f" {tokenizer_max_length} tokens: {removed_text}"
+                 )
+
+             prompt_embeds = components.text_encoder(
+                 text_input_ids.to(device), output_hidden_states=False
+             )
+
+             # Use pooled output of CLIPTextModel
+             prompt_embeds = prompt_embeds.pooler_output
+             prompt_embeds = prompt_embeds.to(
+                 dtype=components.text_encoder.dtype, device=device
+             )
+
+             # Cache the new embeddings
+             CachedFluxTextEncoderStep._cache_prompt_embeds(
+                 prompt, prompt_indices, prompt_embeds, cache_instance, "_clip"
+             )
+
+         # Combine cached and newly encoded embeddings in correct order
+         prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
+             cached_embeds,
+             prompt_indices,
+             prompt_embeds if prompts_to_encode else None,
+             batch_size,
+         )
+
+         # Duplicate for num_images_per_prompt
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+         return prompt_embeds
+
+     @staticmethod
+     def encode_prompt(
+         components,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Union[str, List[str]] = None,
+         device: Optional[torch.device] = None,
+         num_images_per_prompt: int = 1,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         max_sequence_length: int = 512,
+         lora_scale: Optional[float] = None,
+         cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             prompt_2 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                 used in all text-encoders
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+             lora_scale (`float`, *optional*):
+                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+         """
+         device = device or components._execution_device
+
+         # set lora scale so that monkey patched LoRA
+         # function of text encoder can correctly access it
+         if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
+             components._lora_scale = lora_scale
+
+             # dynamically adjust the LoRA scale
+             if components.text_encoder is not None and USE_PEFT_BACKEND:
+                 scale_lora_layers(components.text_encoder, lora_scale)
+             if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                 scale_lora_layers(components.text_encoder_2, lora_scale)
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+
+         if prompt_embeds is None:
+             prompt_2 = prompt_2 or prompt
+             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+             # We only use the pooled prompt output from the CLIPTextModel
+             pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
+                 components,
+                 prompt=prompt,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 cache_instance=cache_instance,
+             )
+             prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
+                 components,
+                 prompt=prompt_2,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+                 cache_instance=cache_instance,
+             )
+
+         if components.text_encoder is not None:
+             if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(components.text_encoder, lora_scale)
+
+         if components.text_encoder_2 is not None:
+             if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(components.text_encoder_2, lora_scale)
+
+         dtype = (
+             components.text_encoder.dtype
+             if components.text_encoder is not None
+             else torch.bfloat16
+         )
+         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+         return prompt_embeds, pooled_prompt_embeds, text_ids
+
+     @torch.no_grad()
+     def __call__(
+         self, components: FluxModularPipeline, state: PipelineState
+     ) -> PipelineState:
+         # Get inputs and intermediates
+         block_state = self.get_block_state(state)
+         self.check_inputs(block_state)
+
+         block_state.device = components._execution_device
+
+         # Encode input prompt
+         block_state.text_encoder_lora_scale = (
+             block_state.joint_attention_kwargs.get("scale", None)
+             if block_state.joint_attention_kwargs is not None
+             else None
+         )
+         (
+             block_state.prompt_embeds,
+             block_state.pooled_prompt_embeds,
+             block_state.text_ids,
+         ) = self.encode_prompt(
+             components,
+             prompt=block_state.prompt,
+             prompt_2=None,
+             prompt_embeds=None,
+             pooled_prompt_embeds=None,
+             device=block_state.device,
+             num_images_per_prompt=1,  # TODO: hardcoded for now.
+             max_sequence_length=512,
+             lora_scale=block_state.text_encoder_lora_scale,
+             cache_instance=self
+             if self.cache is not None
+             else None,  # Pass self as cache_instance
+         )
+
+         # Add outputs
+         self.set_block_state(state, block_state)
+         return components, state
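Beyond encoding, `CachedFluxTextEncoderStep` exposes a few cache-management helpers (`get_cache_size`, `save_cache_to_disk`, `clear_cache_from_disk`). A rough usage sketch, assuming the block resolves to this class via `trust_remote_code` as in the README:

```python
from diffusers.modular_pipelines import ModularPipelineBlocks

# Load the custom block; the auto_map in modular_config.json resolves it to
# block.CachedFluxTextEncoderStep.
prompt_encoder_block = ModularPipelineBlocks.from_pretrained(
    "diffusers/flux-cached-prompt-encoder-custom-block",
    trust_remote_code=True,
)

# Size of cache.safetensors on disk in MB (0 if no cache has been written yet).
print(f"Prompt cache size: {prompt_encoder_block.get_cache_size():.2f} MB")

# Delete cache.safetensors and clear the in-memory cache.
prompt_encoder_block.clear_cache_from_disk()
```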
modular_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_class_name": "CachedFluxTextEncoderStep",
+   "_diffusers_version": "0.35.0.dev0",
+   "auto_map": {
+     "ModularPipelineBlocks": "block.CachedFluxTextEncoderStep"
+   }
+ }
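The `auto_map` entry above is what lets `ModularPipelineBlocks.from_pretrained(..., trust_remote_code=True)` resolve this repository to `CachedFluxTextEncoderStep` in `block.py`. If `block.py` is copied locally, the block can also be constructed directly; a minimal sketch using the constructor arguments it defines (the local cache directory here is just an example):

```python
# Assumes block.py from this repository is importable locally.
from block import CachedFluxTextEncoderStep

# Keep the prompt cache in a project-local directory instead of the default
# ~/.cache/flux_prompt_cache, and reload any previously saved cache.safetensors.
prompt_encoder_block = CachedFluxTextEncoderStep(
    use_cache=True,
    cache_dir="./flux_prompt_cache",
    load_from_disk=True,
)
```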