Upload folder using huggingface_hub
Browse files- README.md +32 -0
- block.py +601 -0
- modular_config.json +7 -0
README.md
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cached Flux Prompt Encoding
|
2 |
+
|
3 |
+
This is a custom block designed to cache input prompts for the Flux model. Prompts encoded prompts are stored in a safetensors file using the hashed prompt string as the key. Prompts existing in the cache are loaded directly from the file, while new prompts are encoded and added to the cache.
|
4 |
+
|
5 |
+
# How to use
|
6 |
+
|
7 |
+
```python
|
8 |
+
import torch
|
9 |
+
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
10 |
+
from diffusers.modular_pipelines.flux.modular_blocks import TEXT2IMAGE_BLOCKS
|
11 |
+
|
12 |
+
prompt_encoder_block = ModularPipelineBlocks.from_pretrained(
|
13 |
+
"diffusers/flux-cached-prompt-encoder-custom-block",
|
14 |
+
trust_remote_code=True
|
15 |
+
)
|
16 |
+
|
17 |
+
blocks = TEXT2IMAGE_BLOCKS.copy().insert("text_encoder", prompt_encoder_block, 0)
|
18 |
+
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks)
|
19 |
+
|
20 |
+
repo_id = "diffusers/modular-FLUX.1-dev"
|
21 |
+
pipe = blocks.init_pipeline(repo_id)
|
22 |
+
pipe.load_components(torch_dtype=torch.bfloat16, device_map="cuda")
|
23 |
+
|
24 |
+
output = pipe(
|
25 |
+
prompt=prompt,
|
26 |
+
num_inference_steps=35,
|
27 |
+
guidance_scale=3.5,
|
28 |
+
output_type="pil",
|
29 |
+
)
|
30 |
+
image = output.values['image']
|
31 |
+
```
|
32 |
+
|
block.py
ADDED
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from diffusers import FluxModularPipeline, ModularPipelineBlocks
|
7 |
+
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
8 |
+
from diffusers.modular_pipelines import PipelineState
|
9 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
10 |
+
ComponentSpec,
|
11 |
+
InputParam,
|
12 |
+
OutputParam,
|
13 |
+
)
|
14 |
+
from diffusers.utils import (
|
15 |
+
USE_PEFT_BACKEND,
|
16 |
+
logger,
|
17 |
+
scale_lora_layers,
|
18 |
+
unscale_lora_layers,
|
19 |
+
)
|
20 |
+
from safetensors import safe_open
|
21 |
+
from safetensors.torch import save_file
|
22 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
23 |
+
|
24 |
+
|
25 |
+
class CachedFluxTextEncoderStep(ModularPipelineBlocks):
|
26 |
+
model_name = "flux"
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
use_cache: bool = True,
|
31 |
+
cache_dir: Optional[str] = None,
|
32 |
+
load_from_disk: bool = True,
|
33 |
+
) -> None:
|
34 |
+
"""Initialize the cached Flux text encoder step.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
|
38 |
+
cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
|
39 |
+
load_from_disk: Whether to load existing cache from disk on initialization. Defaults to True.
|
40 |
+
"""
|
41 |
+
super().__init__()
|
42 |
+
self.cache = {} if use_cache else None
|
43 |
+
if use_cache:
|
44 |
+
self.cache_dir = cache_dir or os.path.join(
|
45 |
+
os.path.expanduser("~"), ".cache", "flux_prompt_cache"
|
46 |
+
)
|
47 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
48 |
+
else:
|
49 |
+
self.cache_dir = None
|
50 |
+
|
51 |
+
# Load existing cache if requested
|
52 |
+
if load_from_disk and use_cache:
|
53 |
+
self.load_cache_from_disk()
|
54 |
+
|
55 |
+
@property
|
56 |
+
def description(self) -> str:
|
57 |
+
return "Text Encoder step that generate text_embeddings to guide the video generation"
|
58 |
+
|
59 |
+
@property
|
60 |
+
def expected_components(self):
|
61 |
+
return [
|
62 |
+
ComponentSpec("text_encoder", CLIPTextModel),
|
63 |
+
ComponentSpec("tokenizer", CLIPTokenizer),
|
64 |
+
ComponentSpec("text_encoder_2", T5EncoderModel),
|
65 |
+
ComponentSpec("tokenizer_2", T5TokenizerFast),
|
66 |
+
]
|
67 |
+
|
68 |
+
@property
|
69 |
+
def expected_configs(self):
|
70 |
+
return []
|
71 |
+
|
72 |
+
@property
|
73 |
+
def inputs(self) -> List[InputParam]:
|
74 |
+
return [
|
75 |
+
InputParam("prompt"),
|
76 |
+
InputParam("prompt_2"),
|
77 |
+
InputParam("joint_attention_kwargs"),
|
78 |
+
]
|
79 |
+
|
80 |
+
@property
|
81 |
+
def intermediate_outputs(self):
|
82 |
+
return [
|
83 |
+
OutputParam(
|
84 |
+
"prompt_embeds",
|
85 |
+
type_hint=torch.Tensor,
|
86 |
+
description="text embeddings used to guide the image generation",
|
87 |
+
),
|
88 |
+
OutputParam(
|
89 |
+
"pooled_prompt_embeds",
|
90 |
+
type_hint=torch.Tensor,
|
91 |
+
description="pooled text embeddings used to guide the image generation",
|
92 |
+
),
|
93 |
+
OutputParam(
|
94 |
+
"text_ids",
|
95 |
+
type_hint=torch.Tensor,
|
96 |
+
description="ids from the text sequence for RoPE",
|
97 |
+
),
|
98 |
+
]
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def check_inputs(block_state):
|
102 |
+
for prompt in [block_state.prompt, block_state.prompt_2]:
|
103 |
+
if prompt is not None and (
|
104 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
105 |
+
):
|
106 |
+
raise ValueError(
|
107 |
+
f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
|
108 |
+
)
|
109 |
+
|
110 |
+
def save_cache_to_disk(self):
|
111 |
+
"""Save the current cache to disk as a safetensors file."""
|
112 |
+
if not self.cache or not self.cache_dir:
|
113 |
+
return
|
114 |
+
|
115 |
+
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
|
116 |
+
|
117 |
+
# Prepare tensors dict for safetensors
|
118 |
+
tensors_to_save = {}
|
119 |
+
for key, tensor in self.cache.items():
|
120 |
+
# Ensure tensor is on CPU before saving
|
121 |
+
cpu_tensor = (
|
122 |
+
tensor.cpu() if tensor.device != torch.device("cpu") else tensor
|
123 |
+
)
|
124 |
+
tensors_to_save[key] = cpu_tensor
|
125 |
+
|
126 |
+
# Save tensors
|
127 |
+
save_file(tensors_to_save, cache_file)
|
128 |
+
logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")
|
129 |
+
|
130 |
+
def load_cache_from_disk(self):
|
131 |
+
"""Load cache from disk using memory-mapped safetensors."""
|
132 |
+
if not self.cache_dir or self.cache is None:
|
133 |
+
return
|
134 |
+
|
135 |
+
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
|
136 |
+
|
137 |
+
if not os.path.exists(cache_file):
|
138 |
+
return
|
139 |
+
|
140 |
+
try:
|
141 |
+
# Open safetensors file in context manager
|
142 |
+
with safe_open(cache_file, framework="pt", device="cpu") as f:
|
143 |
+
loaded_count = 0
|
144 |
+
for key in f.keys():
|
145 |
+
self.cache[key] = f.get_tensor(key)
|
146 |
+
loaded_count += 1
|
147 |
+
|
148 |
+
logger.debug(
|
149 |
+
f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
|
150 |
+
)
|
151 |
+
except Exception as e:
|
152 |
+
logger.warning(f"Failed to load cache from disk: {e}")
|
153 |
+
|
154 |
+
def clear_cache_from_disk(self):
|
155 |
+
"""Clear cached safetensors file from disk."""
|
156 |
+
if not self.cache_dir:
|
157 |
+
return
|
158 |
+
|
159 |
+
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
|
160 |
+
if os.path.exists(cache_file):
|
161 |
+
os.remove(cache_file)
|
162 |
+
logger.info(f"Cleared cache file: {cache_file}")
|
163 |
+
|
164 |
+
# Also clear the in-memory cache
|
165 |
+
if self.cache:
|
166 |
+
self.cache.clear()
|
167 |
+
|
168 |
+
def get_cache_size(self):
|
169 |
+
"""Get the current cache size in MB."""
|
170 |
+
if not self.cache_dir:
|
171 |
+
return 0
|
172 |
+
|
173 |
+
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
|
174 |
+
if os.path.exists(cache_file):
|
175 |
+
return os.path.getsize(cache_file) / (1024 * 1024) # Convert to MB
|
176 |
+
return 0
|
177 |
+
|
178 |
+
@staticmethod
|
179 |
+
def _to_cache_key(prompt: str) -> str:
|
180 |
+
"""Generate a hash key for a single prompt string."""
|
181 |
+
return hashlib.sha256(prompt.encode()).hexdigest()
|
182 |
+
|
183 |
+
@staticmethod
|
184 |
+
def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
|
185 |
+
"""Split prompts into cached and new, returning indices for reconstruction.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
prompts: List of prompt strings to check against cache.
|
189 |
+
cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
|
190 |
+
cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
|
191 |
+
device: Optional device to move cached tensors to.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
tuple: (cached_embeds, prompts_to_encode, prompt_indices)
|
195 |
+
- cached_embeds: List of (idx, embedding) tuples for cached prompts
|
196 |
+
- prompts_to_encode: List of prompts that need encoding
|
197 |
+
- prompt_indices: List of original indices for prompts_to_encode
|
198 |
+
"""
|
199 |
+
cached_embeds = []
|
200 |
+
prompts_to_encode = []
|
201 |
+
prompt_indices = []
|
202 |
+
|
203 |
+
for idx, prompt in enumerate(prompts):
|
204 |
+
cache_key = CachedFluxTextEncoderStep._to_cache_key(prompt + cache_suffix)
|
205 |
+
if (
|
206 |
+
cache_instance
|
207 |
+
and cache_instance.cache
|
208 |
+
and cache_key in cache_instance.cache
|
209 |
+
):
|
210 |
+
cached_tensor = cache_instance.cache[cache_key]
|
211 |
+
# Move tensor to the correct device if specified
|
212 |
+
if device is not None and cached_tensor.device != device:
|
213 |
+
cached_tensor = cached_tensor.to(device)
|
214 |
+
cached_embeds.append((idx, cached_tensor))
|
215 |
+
else:
|
216 |
+
prompts_to_encode.append(prompt)
|
217 |
+
prompt_indices.append(idx)
|
218 |
+
|
219 |
+
return cached_embeds, prompts_to_encode, prompt_indices
|
220 |
+
|
221 |
+
@staticmethod
|
222 |
+
def _cache_prompt_embeds(
|
223 |
+
prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
|
224 |
+
):
|
225 |
+
"""Store newly computed embeddings in cache and save to disk.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
prompts: Original full list of prompts.
|
229 |
+
prompt_indices: Indices of newly encoded prompts in the original list.
|
230 |
+
prompt_embeds: Newly computed embeddings tensor.
|
231 |
+
cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
|
232 |
+
cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
|
233 |
+
"""
|
234 |
+
if not cache_instance or cache_instance.cache is None:
|
235 |
+
return
|
236 |
+
|
237 |
+
for i, idx in enumerate(prompt_indices):
|
238 |
+
cache_key = CachedFluxTextEncoderStep._to_cache_key(
|
239 |
+
prompts[idx] + cache_suffix
|
240 |
+
)
|
241 |
+
# Store in memory cache on CPU to save GPU memory
|
242 |
+
tensor_slice = prompt_embeds[i : i + 1]
|
243 |
+
cache_instance.cache[cache_key] = tensor_slice
|
244 |
+
|
245 |
+
# Save updated cache to disk
|
246 |
+
cache_instance.save_cache_to_disk()
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def _merge_cached_prompt_embeds(
|
250 |
+
cached_embeds, prompt_indices, prompt_embeds, batch_size
|
251 |
+
):
|
252 |
+
"""Merge cached and newly computed embeddings back into original batch order.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
cached_embeds: List of (idx, embedding) tuples from cache.
|
256 |
+
prompt_indices: Indices where new embeddings should be placed.
|
257 |
+
prompt_embeds: Newly computed embeddings tensor, or None if all cached.
|
258 |
+
batch_size: Total batch size for output tensor.
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
torch.Tensor: Combined embeddings tensor in correct batch order.
|
262 |
+
"""
|
263 |
+
all_embeds = [None] * batch_size
|
264 |
+
|
265 |
+
# Place cached embeddings
|
266 |
+
for idx, embed in cached_embeds:
|
267 |
+
all_embeds[idx] = embed
|
268 |
+
|
269 |
+
# Place new embeddings
|
270 |
+
if prompt_embeds is not None:
|
271 |
+
for i, idx in enumerate(prompt_indices):
|
272 |
+
all_embeds[idx] = prompt_embeds[i : i + 1]
|
273 |
+
|
274 |
+
return torch.cat(all_embeds, dim=0)
|
275 |
+
|
276 |
+
@staticmethod
|
277 |
+
def _get_t5_prompt_embeds(
|
278 |
+
components,
|
279 |
+
prompt: Union[str, List[str]] = None,
|
280 |
+
num_images_per_prompt: int = 1,
|
281 |
+
max_sequence_length: int = 512,
|
282 |
+
device: torch.device = None,
|
283 |
+
cache_instance=None,
|
284 |
+
):
|
285 |
+
"""Encode prompts using T5 text encoder with caching support.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
components: Pipeline components containing T5 encoder and tokenizer.
|
289 |
+
prompt: Prompt(s) to encode.
|
290 |
+
num_images_per_prompt: Number of images per prompt for duplication.
|
291 |
+
max_sequence_length: Maximum sequence length for tokenization.
|
292 |
+
device: Device to place tensors on.
|
293 |
+
cache_instance: CachedFluxTextEncoderStep instance for caching, or None.
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
torch.Tensor: T5 prompt embeddings ready for diffusion model.
|
297 |
+
"""
|
298 |
+
dtype = components.text_encoder_2.dtype
|
299 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
300 |
+
batch_size = len(prompt)
|
301 |
+
|
302 |
+
cached_embeds, prompts_to_encode, prompt_indices = (
|
303 |
+
CachedFluxTextEncoderStep._get_cached_prompt_embeds(
|
304 |
+
prompt, cache_instance, "_t5", device
|
305 |
+
)
|
306 |
+
)
|
307 |
+
|
308 |
+
if not prompts_to_encode:
|
309 |
+
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
|
310 |
+
cached_embeds, prompt_indices, None, batch_size
|
311 |
+
)
|
312 |
+
_, seq_len, _ = prompt_embeds.shape
|
313 |
+
|
314 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
315 |
+
prompt_embeds = prompt_embeds.view(
|
316 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
317 |
+
)
|
318 |
+
return prompt_embeds
|
319 |
+
|
320 |
+
if isinstance(components, TextualInversionLoaderMixin):
|
321 |
+
prompts_to_encode = components.maybe_convert_prompt(
|
322 |
+
prompts_to_encode, components.tokenizer_2
|
323 |
+
)
|
324 |
+
|
325 |
+
text_inputs = components.tokenizer_2(
|
326 |
+
prompts_to_encode,
|
327 |
+
padding="max_length",
|
328 |
+
max_length=max_sequence_length,
|
329 |
+
truncation=True,
|
330 |
+
return_length=False,
|
331 |
+
return_overflowing_tokens=False,
|
332 |
+
return_tensors="pt",
|
333 |
+
)
|
334 |
+
text_input_ids = text_inputs.input_ids
|
335 |
+
|
336 |
+
# Check for truncation
|
337 |
+
untruncated_ids = components.tokenizer_2(
|
338 |
+
prompts_to_encode, padding="longest", return_tensors="pt"
|
339 |
+
).input_ids
|
340 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
341 |
+
text_input_ids, untruncated_ids
|
342 |
+
):
|
343 |
+
removed_text = components.tokenizer_2.batch_decode(
|
344 |
+
untruncated_ids[:, max_sequence_length - 1 : -1]
|
345 |
+
)
|
346 |
+
logger.warning(
|
347 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
348 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
349 |
+
)
|
350 |
+
|
351 |
+
prompt_embeds = components.text_encoder_2(
|
352 |
+
text_input_ids.to(device), output_hidden_states=False
|
353 |
+
)[0]
|
354 |
+
|
355 |
+
CachedFluxTextEncoderStep._cache_prompt_embeds(
|
356 |
+
prompt, prompt_indices, prompt_embeds, cache_instance, "_t5"
|
357 |
+
)
|
358 |
+
|
359 |
+
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
|
360 |
+
cached_embeds, prompt_indices, prompt_embeds, batch_size
|
361 |
+
)
|
362 |
+
_, seq_len, _ = prompt_embeds.shape
|
363 |
+
|
364 |
+
# Duplicate for num_images_per_prompt
|
365 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
366 |
+
prompt_embeds = prompt_embeds.view(
|
367 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
368 |
+
)
|
369 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
370 |
+
|
371 |
+
return prompt_embeds
|
372 |
+
|
373 |
+
@staticmethod
|
374 |
+
def _get_clip_prompt_embeds(
|
375 |
+
components,
|
376 |
+
prompt: Union[str, List[str]] = None,
|
377 |
+
num_images_per_prompt: int = 1,
|
378 |
+
device: torch.device = None,
|
379 |
+
cache_instance=None,
|
380 |
+
):
|
381 |
+
"""Encode prompts using CLIP text encoder with caching support.
|
382 |
+
|
383 |
+
Args:
|
384 |
+
components: Pipeline components containing CLIP encoder and tokenizer.
|
385 |
+
prompt: Prompt(s) to encode.
|
386 |
+
num_images_per_prompt: Number of images per prompt for duplication.
|
387 |
+
device: Device to place tensors on.
|
388 |
+
cache_instance: CachedFluxTextEncoderStep instance for caching, or None.
|
389 |
+
|
390 |
+
Returns:
|
391 |
+
torch.Tensor: CLIP pooled prompt embeddings ready for diffusion model.
|
392 |
+
"""
|
393 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
394 |
+
batch_size = len(prompt)
|
395 |
+
|
396 |
+
# Split cached and new prompts
|
397 |
+
cached_embeds, prompts_to_encode, prompt_indices = (
|
398 |
+
CachedFluxTextEncoderStep._get_cached_prompt_embeds(
|
399 |
+
prompt, cache_instance, "_clip", device
|
400 |
+
)
|
401 |
+
)
|
402 |
+
|
403 |
+
# Early return if all prompts are cached
|
404 |
+
if not prompts_to_encode:
|
405 |
+
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
|
406 |
+
cached_embeds, prompt_indices, None, batch_size
|
407 |
+
)
|
408 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
409 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
410 |
+
return prompt_embeds
|
411 |
+
|
412 |
+
if prompts_to_encode:
|
413 |
+
if isinstance(components, TextualInversionLoaderMixin):
|
414 |
+
prompts_to_encode = components.maybe_convert_prompt(
|
415 |
+
prompts_to_encode, components.tokenizer
|
416 |
+
)
|
417 |
+
|
418 |
+
text_inputs = components.tokenizer(
|
419 |
+
prompts_to_encode,
|
420 |
+
padding="max_length",
|
421 |
+
max_length=components.tokenizer.model_max_length,
|
422 |
+
truncation=True,
|
423 |
+
return_overflowing_tokens=False,
|
424 |
+
return_length=False,
|
425 |
+
return_tensors="pt",
|
426 |
+
)
|
427 |
+
|
428 |
+
text_input_ids = text_inputs.input_ids
|
429 |
+
tokenizer_max_length = components.tokenizer.model_max_length
|
430 |
+
untruncated_ids = components.tokenizer(
|
431 |
+
prompts_to_encode, padding="longest", return_tensors="pt"
|
432 |
+
).input_ids
|
433 |
+
|
434 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
435 |
+
-1
|
436 |
+
] and not torch.equal(text_input_ids, untruncated_ids):
|
437 |
+
removed_text = components.tokenizer.batch_decode(
|
438 |
+
untruncated_ids[:, tokenizer_max_length - 1 : -1]
|
439 |
+
)
|
440 |
+
logger.warning(
|
441 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
442 |
+
f" {tokenizer_max_length} tokens: {removed_text}"
|
443 |
+
)
|
444 |
+
|
445 |
+
prompt_embeds = components.text_encoder(
|
446 |
+
text_input_ids.to(device), output_hidden_states=False
|
447 |
+
)
|
448 |
+
|
449 |
+
# Use pooled output of CLIPTextModel
|
450 |
+
prompt_embeds = prompt_embeds.pooler_output
|
451 |
+
prompt_embeds = prompt_embeds.to(
|
452 |
+
dtype=components.text_encoder.dtype, device=device
|
453 |
+
)
|
454 |
+
|
455 |
+
# Cache the new embeddings
|
456 |
+
CachedFluxTextEncoderStep._cache_prompt_embeds(
|
457 |
+
prompt, prompt_indices, prompt_embeds, cache_instance, "_clip"
|
458 |
+
)
|
459 |
+
|
460 |
+
# Combine cached and newly encoded embeddings in correct order
|
461 |
+
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
|
462 |
+
cached_embeds,
|
463 |
+
prompt_indices,
|
464 |
+
prompt_embeds if prompts_to_encode else None,
|
465 |
+
batch_size,
|
466 |
+
)
|
467 |
+
|
468 |
+
# Duplicate for num_images_per_prompt
|
469 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
470 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
471 |
+
|
472 |
+
return prompt_embeds
|
473 |
+
|
474 |
+
@staticmethod
|
475 |
+
def encode_prompt(
|
476 |
+
components,
|
477 |
+
prompt: Union[str, List[str]] = None,
|
478 |
+
prompt_2: Union[str, List[str]] = None,
|
479 |
+
device: Optional[torch.device] = None,
|
480 |
+
num_images_per_prompt: int = 1,
|
481 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
482 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
483 |
+
max_sequence_length: int = 512,
|
484 |
+
lora_scale: Optional[float] = None,
|
485 |
+
cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
|
486 |
+
):
|
487 |
+
r"""
|
488 |
+
Encodes the prompt into text encoder hidden states.
|
489 |
+
|
490 |
+
Args:
|
491 |
+
prompt (`str` or `List[str]`, *optional*):
|
492 |
+
prompt to be encoded
|
493 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
494 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
495 |
+
used in all text-encoders
|
496 |
+
device: (`torch.device`):
|
497 |
+
torch device
|
498 |
+
num_images_per_prompt (`int`):
|
499 |
+
number of images that should be generated per prompt
|
500 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
501 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
502 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
503 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
504 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
505 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
506 |
+
lora_scale (`float`, *optional*):
|
507 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
508 |
+
"""
|
509 |
+
device = device or components._execution_device
|
510 |
+
|
511 |
+
# set lora scale so that monkey patched LoRA
|
512 |
+
# function of text encoder can correctly access it
|
513 |
+
if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
|
514 |
+
components._lora_scale = lora_scale
|
515 |
+
|
516 |
+
# dynamically adjust the LoRA scale
|
517 |
+
if components.text_encoder is not None and USE_PEFT_BACKEND:
|
518 |
+
scale_lora_layers(components.text_encoder, lora_scale)
|
519 |
+
if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
520 |
+
scale_lora_layers(components.text_encoder_2, lora_scale)
|
521 |
+
|
522 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
523 |
+
|
524 |
+
if prompt_embeds is None:
|
525 |
+
prompt_2 = prompt_2 or prompt
|
526 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
527 |
+
|
528 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
529 |
+
pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
|
530 |
+
components,
|
531 |
+
prompt=prompt,
|
532 |
+
device=device,
|
533 |
+
num_images_per_prompt=num_images_per_prompt,
|
534 |
+
cache_instance=cache_instance,
|
535 |
+
)
|
536 |
+
prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
|
537 |
+
components,
|
538 |
+
prompt=prompt_2,
|
539 |
+
num_images_per_prompt=num_images_per_prompt,
|
540 |
+
max_sequence_length=max_sequence_length,
|
541 |
+
device=device,
|
542 |
+
cache_instance=cache_instance,
|
543 |
+
)
|
544 |
+
|
545 |
+
if components.text_encoder is not None:
|
546 |
+
if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
547 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
548 |
+
unscale_lora_layers(components.text_encoder, lora_scale)
|
549 |
+
|
550 |
+
if components.text_encoder_2 is not None:
|
551 |
+
if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
552 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
553 |
+
unscale_lora_layers(components.text_encoder_2, lora_scale)
|
554 |
+
|
555 |
+
dtype = (
|
556 |
+
components.text_encoder.dtype
|
557 |
+
if components.text_encoder is not None
|
558 |
+
else torch.bfloat16
|
559 |
+
)
|
560 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
561 |
+
|
562 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
563 |
+
|
564 |
+
@torch.no_grad()
|
565 |
+
def __call__(
|
566 |
+
self, components: FluxModularPipeline, state: PipelineState
|
567 |
+
) -> PipelineState:
|
568 |
+
# Get inputs and intermediates
|
569 |
+
block_state = self.get_block_state(state)
|
570 |
+
self.check_inputs(block_state)
|
571 |
+
|
572 |
+
block_state.device = components._execution_device
|
573 |
+
|
574 |
+
# Encode input prompt
|
575 |
+
block_state.text_encoder_lora_scale = (
|
576 |
+
block_state.joint_attention_kwargs.get("scale", None)
|
577 |
+
if block_state.joint_attention_kwargs is not None
|
578 |
+
else None
|
579 |
+
)
|
580 |
+
(
|
581 |
+
block_state.prompt_embeds,
|
582 |
+
block_state.pooled_prompt_embeds,
|
583 |
+
block_state.text_ids,
|
584 |
+
) = self.encode_prompt(
|
585 |
+
components,
|
586 |
+
prompt=block_state.prompt,
|
587 |
+
prompt_2=None,
|
588 |
+
prompt_embeds=None,
|
589 |
+
pooled_prompt_embeds=None,
|
590 |
+
device=block_state.device,
|
591 |
+
num_images_per_prompt=1, # TODO: hardcoded for now.
|
592 |
+
max_sequence_length=512,
|
593 |
+
lora_scale=block_state.text_encoder_lora_scale,
|
594 |
+
cache_instance=self
|
595 |
+
if self.cache is not None
|
596 |
+
else None, # Pass self as cache_instance
|
597 |
+
)
|
598 |
+
|
599 |
+
# Add outputs
|
600 |
+
self.set_block_state(state, block_state)
|
601 |
+
return components, state
|
modular_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "CachedFluxTextEncoderStep",
|
3 |
+
"_diffusers_version": "0.35.0.dev0",
|
4 |
+
"auto_map": {
|
5 |
+
"ModularPipelineBlocks": "block.CachedFluxTextEncoderStep"
|
6 |
+
}
|
7 |
+
}
|