neuralvfx committed on
Commit addfdf3 · verified · 1 Parent(s): 6e9fe4d

Initial upload of LibreFlux ControlNet pipeline

__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .pipeline import (
2
+ LibreFluxControlNetPipeline,
3
+ LibreFluxTransformer2DModel,
4
+ LibreFluxControlNetModel,
5
+ )
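A minimal usage sketch of the re-exported classes, assuming this repository is importable as a local package (the package name `libreflux` is a placeholder) and that `pipeline.py` exposes a `DiffusionPipeline`-style `from_pretrained`:

import torch
from libreflux import LibreFluxControlNetPipeline  # placeholder package name

# "path-or-repo-id" stands in for wherever this repository lives locally or on
# the Hub; the actual id is not stated in this commit.
pipe = LibreFluxControlNetPipeline.from_pretrained(
    "path-or-repo-id",
    torch_dtype=torch.bfloat16,
)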
backup_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
controlnet/__init__.py ADDED
File without changes
controlnet/config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "_class_name": "FluxControlNetModel",
3
+ "_diffusers_version": "0.32.0",
4
+ "attention_head_dim": 128,
5
+ "axes_dims_rope": [
6
+ 16,
7
+ 56,
8
+ 56
9
+ ],
10
+ "conditioning_embedding_channels": null,
11
+ "guidance_embeds": true,
12
+ "in_channels": 64,
13
+ "joint_attention_dim": 4096,
14
+ "num_attention_heads": 24,
15
+ "num_layers": 2,
16
+ "num_mode": null,
17
+ "num_single_layers": 4,
18
+ "patch_size": 1,
19
+ "pooled_projection_dim": 768
20
+ }
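This config registers a compact ControlNet: 2 dual-stream blocks (`num_layers`) and 4 single-stream blocks (`num_single_layers`) at the full Flux width (24 heads × 128 = 3072), with guidance embeddings enabled. A minimal loading sketch, assuming a placeholder repo id and that `LibreFluxControlNetModel` (defined in `controlnet/net.py` below) is used to load it even though `_class_name` names the stock `FluxControlNetModel`:

import torch
from libreflux import LibreFluxControlNetModel  # placeholder package name

controlnet = LibreFluxControlNetModel.from_pretrained(
    "path-or-repo-id",        # placeholder; the actual repo id is not stated here
    subfolder="controlnet",
    torch_dtype=torch.bfloat16,
)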
controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06e84cb264fc8bf98cc6c1ed5e53a606d061c4440c5ba9164f941dfce4f054b6
3
+ size 2739920936
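The weights are stored as a Git LFS object, so only the pointer is shown above. A quick integrity check against the advertised digest and size, assuming the file has been downloaded to the path below:

import hashlib
import os

path = "controlnet/diffusion_pytorch_model.safetensors"  # assumed local download path
expected_sha256 = "06e84cb264fc8bf98cc6c1ed5e53a606d061c4440c5ba9164f941dfce4f054b6"
expected_size = 2739920936  # bytes, from the LFS pointer

assert os.path.getsize(path) == expected_size, "size mismatch"
digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)
assert digest.hexdigest() == expected_sha256, "sha256 mismatch"
print("LFS object verified")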
controlnet/net.py ADDED
@@ -0,0 +1,1507 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This was modified from the ControlNet repo
16
+
17
+
18
+ import inspect
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
22
+
23
+ import numpy as np
24
+ import torch
25
+ from transformers import (
26
+ CLIPTextModel,
27
+ CLIPTokenizer,
28
+ T5EncoderModel,
29
+ T5TokenizerFast,
30
+ )
31
+
32
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
33
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
34
+ from diffusers.models.autoencoders import AutoencoderKL
35
+ ### MERGING THESE ###
36
+ # from src.models.transformer import FluxTransformer2DModel
37
+ # from src.models.controlnet_flux import FluxControlNetModel
38
+ #############
39
+
40
+ ##########################################
41
+ ########### ATTENTION MERGE ##############
42
+ ##########################################
43
+
44
+ import torch
45
+ from torch import Tensor, FloatTensor
46
+ from torch.nn import functional as F
47
+ from einops import rearrange
48
+ from diffusers.models.attention_processor import Attention
49
+ from diffusers.models.embeddings import apply_rotary_emb
50
+
51
+
52
+
53
+ class FluxFusedSDPAProcessor:
54
+ """
55
+ Fused QKV processor using PyTorch's scaled_dot_product_attention.
56
+ Uses fused projections but splits for attention computation.
57
+ """
58
+
59
+ def __init__(self):
60
+ if not hasattr(F, "scaled_dot_product_attention"):
61
+ raise ImportError(
62
+ "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
63
+ )
64
+
65
+ def __call__(
66
+ self,
67
+ attn,
68
+ hidden_states: FloatTensor,
69
+ encoder_hidden_states: FloatTensor = None,
70
+ attention_mask: FloatTensor = None,
71
+ image_rotary_emb: Tensor = None,
72
+ ) -> FloatTensor:
73
+ input_ndim = hidden_states.ndim
74
+ if input_ndim == 4:
75
+ batch_size, channel, height, width = hidden_states.shape
76
+ hidden_states = hidden_states.view(
77
+ batch_size, channel, height * width
78
+ ).transpose(1, 2)
79
+
80
+ context_input_ndim = (
81
+ encoder_hidden_states.ndim if encoder_hidden_states is not None else None
82
+ )
83
+ if context_input_ndim == 4:
84
+ batch_size, channel, height, width = encoder_hidden_states.shape
85
+ encoder_hidden_states = encoder_hidden_states.view(
86
+ batch_size, channel, height * width
87
+ ).transpose(1, 2)
88
+
89
+ batch_size = (
90
+ encoder_hidden_states.shape[0]
91
+ if encoder_hidden_states is not None
92
+ else hidden_states.shape[0]
93
+ )
94
+
95
+ # Single attention case (no encoder states)
96
+ if encoder_hidden_states is None:
97
+ # Use fused QKV projection
98
+ qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
99
+ inner_dim = qkv.shape[-1] // 3
100
+ head_dim = inner_dim // attn.heads
101
+ seq_len = hidden_states.shape[1]
102
+
103
+ # Split and reshape
104
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
105
+ query, key, value = qkv.unbind(
106
+ dim=2
107
+ ) # Each is (batch, seq_len, heads, head_dim)
108
+
109
+ # Transpose to (batch, heads, seq_len, head_dim)
110
+ query = query.transpose(1, 2)
111
+ key = key.transpose(1, 2)
112
+ value = value.transpose(1, 2)
113
+
114
+ # Apply norms if needed
115
+ if attn.norm_q is not None:
116
+ query = attn.norm_q(query)
117
+ if attn.norm_k is not None:
118
+ key = attn.norm_k(key)
119
+
120
+ # Apply RoPE if needed
121
+ if image_rotary_emb is not None:
122
+ query = apply_rotary_emb(query, image_rotary_emb)
123
+ key = apply_rotary_emb(key, image_rotary_emb)
124
+
125
+ # SDPA
126
+ hidden_states = F.scaled_dot_product_attention(
127
+ query,
128
+ key,
129
+ value,
130
+ attn_mask=attention_mask,
131
+ dropout_p=0.0,
132
+ is_causal=False,
133
+ )
134
+
135
+ # Reshape back
136
+ hidden_states = hidden_states.transpose(1, 2).reshape(
137
+ batch_size, -1, attn.heads * head_dim
138
+ )
139
+ hidden_states = hidden_states.to(query.dtype)
140
+
141
+ if input_ndim == 4:
142
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
143
+ batch_size, channel, height, width
144
+ )
145
+
146
+ return hidden_states
147
+
148
+ # Joint attention case (with encoder states)
149
+ else:
150
+ # Process self-attention QKV
151
+ qkv = attn.to_qkv(hidden_states)
152
+ inner_dim = qkv.shape[-1] // 3
153
+ head_dim = inner_dim // attn.heads
154
+ seq_len = hidden_states.shape[1]
155
+
156
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
157
+ query, key, value = qkv.unbind(dim=2)
158
+
159
+ # Transpose to (batch, heads, seq_len, head_dim)
160
+ query = query.transpose(1, 2)
161
+ key = key.transpose(1, 2)
162
+ value = value.transpose(1, 2)
163
+
164
+ # Apply norms if needed
165
+ if attn.norm_q is not None:
166
+ query = attn.norm_q(query)
167
+ if attn.norm_k is not None:
168
+ key = attn.norm_k(key)
169
+
170
+ # Process encoder QKV
171
+ encoder_seq_len = encoder_hidden_states.shape[1]
172
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
173
+ encoder_qkv = encoder_qkv.view(
174
+ batch_size, encoder_seq_len, 3, attn.heads, head_dim
175
+ )
176
+ encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2)
177
+
178
+ # Transpose to (batch, heads, seq_len, head_dim)
179
+ encoder_query = encoder_query.transpose(1, 2)
180
+ encoder_key = encoder_key.transpose(1, 2)
181
+ encoder_value = encoder_value.transpose(1, 2)
182
+
183
+ # Apply encoder norms if needed
184
+ if attn.norm_added_q is not None:
185
+ encoder_query = attn.norm_added_q(encoder_query)
186
+ if attn.norm_added_k is not None:
187
+ encoder_key = attn.norm_added_k(encoder_key)
188
+
189
+ # Concatenate encoder and self-attention
190
+ query = torch.cat([encoder_query, query], dim=2)
191
+ key = torch.cat([encoder_key, key], dim=2)
192
+ value = torch.cat([encoder_value, value], dim=2)
193
+
194
+ # Apply RoPE if needed
195
+ if image_rotary_emb is not None:
196
+ query = apply_rotary_emb(query, image_rotary_emb)
197
+ key = apply_rotary_emb(key, image_rotary_emb)
198
+
199
+ # SDPA
200
+ hidden_states = F.scaled_dot_product_attention(
201
+ query,
202
+ key,
203
+ value,
204
+ attn_mask=attention_mask,
205
+ dropout_p=0.0,
206
+ is_causal=False,
207
+ )
208
+
209
+ # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
210
+ hidden_states = hidden_states.transpose(1, 2).reshape(
211
+ batch_size, -1, attn.heads * head_dim
212
+ )
213
+ hidden_states = hidden_states.to(query.dtype)
214
+
215
+ # Split encoder and self outputs
216
+ encoder_hidden_states = hidden_states[:, :encoder_seq_len]
217
+ hidden_states = hidden_states[:, encoder_seq_len:]
218
+
219
+ # Output projections
220
+ hidden_states = attn.to_out[0](hidden_states)
221
+ hidden_states = attn.to_out[1](hidden_states) # dropout
222
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
223
+
224
+ # Reshape if needed
225
+ if input_ndim == 4:
226
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
227
+ batch_size, channel, height, width
228
+ )
229
+ if context_input_ndim == 4:
230
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
231
+ batch_size, channel, height, width
232
+ )
233
+
234
+ return hidden_states, encoder_hidden_states
235
+
236
+
237
+ class FluxSingleFusedSDPAProcessor:
238
+ """
239
+ Fused QKV processor for single attention (no encoder states).
240
+ Simpler version for self-attention only blocks.
241
+ """
242
+
243
+ def __init__(self):
244
+ if not hasattr(F, "scaled_dot_product_attention"):
245
+ raise ImportError(
246
+ "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
247
+ )
248
+
249
+ def __call__(
250
+ self,
251
+ attn,
252
+ hidden_states: Tensor,
253
+ encoder_hidden_states: Tensor = None,
254
+ attention_mask: FloatTensor = None,
255
+ image_rotary_emb: Tensor = None,
256
+ ) -> Tensor:
257
+ input_ndim = hidden_states.ndim
258
+ if input_ndim == 4:
259
+ batch_size, channel, height, width = hidden_states.shape
260
+ hidden_states = hidden_states.view(
261
+ batch_size, channel, height * width
262
+ ).transpose(1, 2)
263
+
264
+ batch_size, seq_len, _ = hidden_states.shape
265
+
266
+ # Use fused QKV projection
267
+ qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
268
+ inner_dim = qkv.shape[-1] // 3
269
+ head_dim = inner_dim // attn.heads
270
+
271
+ # Split and reshape in one go
272
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
273
+ qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided
274
+ query, key, value = [
275
+ t.contiguous() for t in qkv.unbind(0) # make each view dense
276
+ ]
277
+ # Now each is (batch, heads, seq_len, head_dim)
278
+
279
+ # Apply norms if needed
280
+ if attn.norm_q is not None:
281
+ query = attn.norm_q(query)
282
+ if attn.norm_k is not None:
283
+ key = attn.norm_k(key)
284
+
285
+ # Apply RoPE if needed
286
+ if image_rotary_emb is not None:
287
+ query = apply_rotary_emb(query, image_rotary_emb)
288
+ key = apply_rotary_emb(key, image_rotary_emb)
289
+
290
+ # SDPA
291
+ hidden_states = F.scaled_dot_product_attention(
292
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
293
+ )
294
+
295
+ # Reshape back
296
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
297
+ hidden_states = hidden_states.to(query.dtype)
298
+
299
+ if input_ndim == 4:
300
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
301
+ batch_size, channel, height, width
302
+ )
303
+
304
+ return hidden_states
305
+
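# Illustrative sketch of how the two fused processors above could be enabled on a
# model exposing the `attn_processors` / `set_attn_processor` helpers defined later
# in this file. Assumption: the installed diffusers version provides
# Attention.fuse_projections(), which builds the fused to_qkv / to_added_qkv layers
# these processors expect.
def enable_fused_sdpa(model):
    from diffusers.models.attention_processor import Attention as _Attention

    # Fuse the separate q/k/v (and added q/k/v) projections first; without this,
    # attn.to_qkv / attn.to_added_qkv do not exist.
    for module in model.modules():
        if isinstance(module, _Attention):
            module.fuse_projections()

    # Dual-stream blocks get the joint processor, single-stream blocks the
    # self-attention-only variant.
    procs = {}
    for name in model.attn_processors:
        if name.startswith("single_transformer_blocks"):
            procs[name] = FluxSingleFusedSDPAProcessor()
        else:
            procs[name] = FluxFusedSDPAProcessor()
    model.set_attn_processor(procs)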
306
+ #################################
307
+ ##### TRANSFORMER MERGE #########
308
+ #################################
309
+
310
+ from typing import Any, Dict, List, Optional, Tuple, Union
311
+
312
+ import torch
313
+ import torch.nn as nn
314
+ import torch.nn.functional as F
315
+ import numpy as np
316
+
317
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
318
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
319
+ from diffusers.models.attention import FeedForward
320
+ from diffusers.models.attention_processor import (
321
+ Attention,
322
+ AttentionProcessor,
323
+ )
324
+ from diffusers.models.modeling_utils import ModelMixin
325
+ from diffusers.models.normalization import (
326
+ AdaLayerNormContinuous,
327
+ AdaLayerNormZero,
328
+ AdaLayerNormZeroSingle,
329
+ )
330
+ from diffusers.utils import (
331
+ USE_PEFT_BACKEND,
332
+ is_torch_version,
333
+ logging,
334
+ scale_lora_layers,
335
+ unscale_lora_layers,
336
+ )
337
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
338
+ from diffusers.models.embeddings import (
339
+ CombinedTimestepGuidanceTextProjEmbeddings,
340
+ CombinedTimestepTextProjEmbeddings,
341
+ FluxPosEmbed,
342
+ )
343
+
344
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
345
+ from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel
346
+
347
+
348
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
349
+
350
+ is_flash_attn_available = False
351
+
352
+
353
+
354
+ class FluxAttnProcessor2_0:
355
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
356
+
357
+ def __init__(self):
358
+ if not hasattr(F, "scaled_dot_product_attention"):
359
+ raise ImportError(
360
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
361
+ )
362
+
363
+ def __call__(
364
+ self,
365
+ attn: Attention,
366
+ hidden_states: torch.FloatTensor,
367
+ encoder_hidden_states: torch.FloatTensor = None,
368
+ attention_mask: Optional[torch.FloatTensor] = None,
369
+ image_rotary_emb: Optional[torch.Tensor] = None,
370
+ ) -> torch.FloatTensor:
371
+ batch_size, _, _ = (
372
+ hidden_states.shape
373
+ if encoder_hidden_states is None
374
+ else encoder_hidden_states.shape
375
+ )
376
+
377
+ # `sample` projections.
378
+ query = attn.to_q(hidden_states)
379
+ key = attn.to_k(hidden_states)
380
+ value = attn.to_v(hidden_states)
381
+
382
+ inner_dim = key.shape[-1]
383
+ head_dim = inner_dim // attn.heads
384
+
385
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
386
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
387
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
388
+
389
+ if attn.norm_q is not None:
390
+ query = attn.norm_q(query)
391
+ if attn.norm_k is not None:
392
+ key = attn.norm_k(key)
393
+
394
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
395
+ if encoder_hidden_states is not None:
396
+ # `context` projections.
397
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
398
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
399
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
400
+
401
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
402
+ batch_size, -1, attn.heads, head_dim
403
+ ).transpose(1, 2)
404
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
405
+ batch_size, -1, attn.heads, head_dim
406
+ ).transpose(1, 2)
407
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
408
+ batch_size, -1, attn.heads, head_dim
409
+ ).transpose(1, 2)
410
+
411
+ if attn.norm_added_q is not None:
412
+ encoder_hidden_states_query_proj = attn.norm_added_q(
413
+ encoder_hidden_states_query_proj
414
+ )
415
+ if attn.norm_added_k is not None:
416
+ encoder_hidden_states_key_proj = attn.norm_added_k(
417
+ encoder_hidden_states_key_proj
418
+ )
419
+
420
+ # attention
421
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
422
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
423
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
424
+
425
+ if image_rotary_emb is not None:
426
+ from diffusers.models.embeddings import apply_rotary_emb
427
+
428
+ query = apply_rotary_emb(query, image_rotary_emb)
429
+ key = apply_rotary_emb(key, image_rotary_emb)
430
+
431
+ if attention_mask is not None:
432
+ #print ('Attention Used')
433
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
434
+ attention_mask = (attention_mask > 0).bool()
435
+ # Edit 17 - match attention mask dtype to query dtype
436
+ attention_mask = attention_mask.to(
437
+ device=hidden_states.device, dtype=query.dtype
438
+ )
439
+
440
+ hidden_states = F.scaled_dot_product_attention(
441
+ query,
442
+ key,
443
+ value,
444
+ dropout_p=0.0,
445
+ is_causal=False,
446
+ attn_mask=attention_mask,
447
+ )
448
+ hidden_states = hidden_states.transpose(1, 2).reshape(
449
+ batch_size, -1, attn.heads * head_dim
450
+ )
451
+ hidden_states = hidden_states.to(query.dtype)
452
+
453
+ if encoder_hidden_states is not None:
454
+ encoder_hidden_states, hidden_states = (
455
+ hidden_states[:, : encoder_hidden_states.shape[1]],
456
+ hidden_states[:, encoder_hidden_states.shape[1] :],
457
+ )
458
+
459
+ # linear proj
460
+ hidden_states = attn.to_out[0](hidden_states)
461
+ # dropout
462
+ hidden_states = attn.to_out[1](hidden_states)
463
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
464
+
465
+ return hidden_states, encoder_hidden_states
466
+ return hidden_states
467
+
468
+
469
+ def expand_flux_attention_mask(
470
+ hidden_states: torch.Tensor,
471
+ attn_mask: torch.Tensor,
472
+ ) -> torch.Tensor:
473
+ """
474
+ Expand a mask so that the image is included.
475
+ """
476
+ bsz = attn_mask.shape[0]
477
+ assert bsz == hidden_states.shape[0]
478
+ residual_seq_len = hidden_states.shape[1]
479
+ mask_seq_len = attn_mask.shape[1]
480
+
481
+ expanded_mask = torch.ones(bsz, residual_seq_len)
482
+ expanded_mask[:, :mask_seq_len] = attn_mask
483
+
484
+ return expanded_mask
485
+
486
+
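# Worked example for expand_flux_attention_mask above: a (batch, text_len) mask is
# padded with ones so it covers the full text+image token sequence (shapes and
# values below are arbitrary placeholders).
def _expand_mask_example():
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])  # 5 text tokens, last 2 padded out
    hidden = torch.randn(1, 9, 3072)                   # 5 text + 4 image tokens
    full = expand_flux_attention_mask(hidden, mask)    # image positions filled with ones
    assert full.shape == (1, 9)
    return full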
487
+ @maybe_allow_in_graph
488
+ class FluxSingleTransformerBlock(nn.Module):
489
+ r"""
490
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
491
+
492
+ Reference: https://arxiv.org/abs/2403.03206
493
+
494
+ Parameters:
495
+ dim (`int`): The number of channels in the input and output.
496
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
497
+ attention_head_dim (`int`): The number of channels in each head.
498
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
499
+ processing of `context` conditions.
500
+ """
501
+
502
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
503
+ super().__init__()
504
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
505
+
506
+ self.norm = AdaLayerNormZeroSingle(dim)
507
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
508
+ self.act_mlp = nn.GELU(approximate="tanh")
509
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
510
+
511
+ processor = FluxAttnProcessor2_0()
512
+ self.attn = Attention(
513
+ query_dim=dim,
514
+ cross_attention_dim=None,
515
+ dim_head=attention_head_dim,
516
+ heads=num_attention_heads,
517
+ out_dim=dim,
518
+ bias=True,
519
+ processor=processor,
520
+ qk_norm="rms_norm",
521
+ eps=1e-6,
522
+ pre_only=True,
523
+ )
524
+
525
+ def forward(
526
+ self,
527
+ hidden_states: torch.FloatTensor,
528
+ temb: torch.FloatTensor,
529
+ image_rotary_emb=None,
530
+ attention_mask: Optional[torch.Tensor] = None,
531
+ ):
532
+ residual = hidden_states
533
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
534
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
535
+
536
+ if attention_mask is not None:
537
+ attention_mask = expand_flux_attention_mask(
538
+ hidden_states,
539
+ attention_mask,
540
+ )
541
+
542
+ attn_output = self.attn(
543
+ hidden_states=norm_hidden_states,
544
+ image_rotary_emb=image_rotary_emb,
545
+ attention_mask=attention_mask,
546
+ )
547
+
548
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
549
+ gate = gate.unsqueeze(1)
550
+ hidden_states = gate * self.proj_out(hidden_states)
551
+ hidden_states = residual + hidden_states
552
+
553
+ if hidden_states.dtype == torch.float16:
554
+ hidden_states = hidden_states.clip(-65504, 65504)
555
+
556
+ return hidden_states
557
+
558
+
559
+ @maybe_allow_in_graph
560
+ class FluxTransformerBlock(nn.Module):
561
+ r"""
562
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
563
+
564
+ Reference: https://arxiv.org/abs/2403.03206
565
+
566
+ Parameters:
567
+ dim (`int`): The number of channels in the input and output.
568
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
569
+ attention_head_dim (`int`): The number of channels in each head.
570
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
571
+ processing of `context` conditions.
572
+ """
573
+
574
+ def __init__(
575
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
576
+ ):
577
+ super().__init__()
578
+
579
+ self.norm1 = AdaLayerNormZero(dim)
580
+
581
+ self.norm1_context = AdaLayerNormZero(dim)
582
+
583
+ if hasattr(F, "scaled_dot_product_attention"):
584
+ processor = FluxAttnProcessor2_0()
585
+ else:
586
+ raise ValueError(
587
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
588
+ )
589
+ self.attn = Attention(
590
+ query_dim=dim,
591
+ cross_attention_dim=None,
592
+ added_kv_proj_dim=dim,
593
+ dim_head=attention_head_dim,
594
+ heads=num_attention_heads,
595
+ out_dim=dim,
596
+ context_pre_only=False,
597
+ bias=True,
598
+ processor=processor,
599
+ qk_norm=qk_norm,
600
+ eps=eps,
601
+ )
602
+
603
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
604
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
605
+
606
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
607
+ self.ff_context = FeedForward(
608
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
609
+ )
610
+
611
+ # let chunk size default to None
612
+ self._chunk_size = None
613
+ self._chunk_dim = 0
614
+
615
+ def forward(
616
+ self,
617
+ hidden_states: torch.FloatTensor,
618
+ encoder_hidden_states: torch.FloatTensor,
619
+ temb: torch.FloatTensor,
620
+ image_rotary_emb=None,
621
+ attention_mask: Optional[torch.Tensor] = None,
622
+ ):
623
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
624
+ hidden_states, emb=temb
625
+ )
626
+
627
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
628
+ self.norm1_context(encoder_hidden_states, emb=temb)
629
+ )
630
+
631
+ if attention_mask is not None:
632
+ attention_mask = expand_flux_attention_mask(
633
+ torch.cat([encoder_hidden_states, hidden_states], dim=1),
634
+ attention_mask,
635
+ )
636
+
637
+ # Attention.
638
+ attention_outputs = self.attn(
639
+ hidden_states=norm_hidden_states,
640
+ encoder_hidden_states=norm_encoder_hidden_states,
641
+ image_rotary_emb=image_rotary_emb,
642
+ attention_mask=attention_mask,
643
+ )
644
+ if len(attention_outputs) == 2:
645
+ attn_output, context_attn_output = attention_outputs
646
+ elif len(attention_outputs) == 3:
647
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
648
+
649
+ # Process attention outputs for the `hidden_states`.
650
+ attn_output = gate_msa.unsqueeze(1) * attn_output
651
+ hidden_states = hidden_states + attn_output
652
+
653
+ norm_hidden_states = self.norm2(hidden_states)
654
+ norm_hidden_states = (
655
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
656
+ )
657
+
658
+ ff_output = self.ff(norm_hidden_states)
659
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
660
+
661
+ hidden_states = hidden_states + ff_output
662
+ if len(attention_outputs) == 3:
663
+ hidden_states = hidden_states + ip_attn_output
664
+
665
+ # Process attention outputs for the `encoder_hidden_states`.
666
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
667
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
668
+
669
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
670
+ norm_encoder_hidden_states = (
671
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
672
+ + c_shift_mlp[:, None]
673
+ )
674
+
675
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
676
+ encoder_hidden_states = (
677
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
678
+ )
679
+
680
+ if encoder_hidden_states.dtype == torch.float16:
681
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
682
+
683
+ return encoder_hidden_states, hidden_states
684
+
685
+
686
+ class LibreFluxTransformer2DModel(
687
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
688
+ ):
689
+ """
690
+ The Transformer model introduced in Flux.
691
+
692
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
693
+
694
+ Parameters:
695
+ patch_size (`int`): Patch size to turn the input data into small patches.
696
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
697
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
698
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
699
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
700
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
701
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
702
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
703
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
704
+ """
705
+
706
+ _supports_gradient_checkpointing = True
707
+
708
+ @register_to_config
709
+ def __init__(
710
+ self,
711
+ patch_size: int = 1,
712
+ in_channels: int = 64,
713
+ num_layers: int = 19,
714
+ num_single_layers: int = 38,
715
+ attention_head_dim: int = 128,
716
+ num_attention_heads: int = 24,
717
+ joint_attention_dim: int = 4096,
718
+ pooled_projection_dim: int = 768,
719
+ guidance_embeds: bool = False,
720
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
721
+ ):
722
+ super().__init__()
723
+ self.out_channels = in_channels
724
+ self.inner_dim = (
725
+ self.config.num_attention_heads * self.config.attention_head_dim
726
+ )
727
+
728
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
729
+ text_time_guidance_cls = (
730
+ CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection)
731
+ if guidance_embeds
732
+ else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection)
733
+ )
734
+ self.time_text_embed = text_time_guidance_cls(
735
+ embedding_dim=self.inner_dim,
736
+ pooled_projection_dim=self.config.pooled_projection_dim,
737
+ )
738
+
739
+ self.context_embedder = nn.Linear(
740
+ self.config.joint_attention_dim, self.inner_dim
741
+ )
742
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
743
+
744
+ self.transformer_blocks = nn.ModuleList(
745
+ [
746
+ FluxTransformerBlock(
747
+ dim=self.inner_dim,
748
+ num_attention_heads=self.config.num_attention_heads,
749
+ attention_head_dim=self.config.attention_head_dim,
750
+ )
751
+ for i in range(self.config.num_layers)
752
+ ]
753
+ )
754
+
755
+ self.single_transformer_blocks = nn.ModuleList(
756
+ [
757
+ FluxSingleTransformerBlock(
758
+ dim=self.inner_dim,
759
+ num_attention_heads=self.config.num_attention_heads,
760
+ attention_head_dim=self.config.attention_head_dim,
761
+ )
762
+ for i in range(self.config.num_single_layers)
763
+ ]
764
+ )
765
+
766
+ self.norm_out = AdaLayerNormContinuous(
767
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
768
+ )
769
+ self.proj_out = nn.Linear(
770
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
771
+ )
772
+
773
+ self.gradient_checkpointing = False
774
+ # lets users restrict checkpointing to every nth block
775
+ self.gradient_checkpointing_interval = None
776
+
777
+ def set_gradient_checkpointing_interval(self, value: int):
778
+ self.gradient_checkpointing_interval = value
779
+
780
+ @property
781
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
782
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
783
+ r"""
784
+ Returns:
785
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
786
+ indexed by its weight name.
787
+ """
788
+ # set recursively
789
+ processors = {}
790
+
791
+ def fn_recursive_add_processors(
792
+ name: str,
793
+ module: torch.nn.Module,
794
+ processors: Dict[str, AttentionProcessor],
795
+ ):
796
+ if hasattr(module, "get_processor"):
797
+ processors[f"{name}.processor"] = module.get_processor()
798
+
799
+ for sub_name, child in module.named_children():
800
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
801
+
802
+ return processors
803
+
804
+ for name, module in self.named_children():
805
+ fn_recursive_add_processors(name, module, processors)
806
+
807
+ return processors
808
+
809
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
810
+ def set_attn_processor(
811
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
812
+ ):
813
+ r"""
814
+ Sets the attention processor to use to compute attention.
815
+
816
+ Parameters:
817
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
818
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
819
+ for **all** `Attention` layers.
820
+
821
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
822
+ processor. This is strongly recommended when setting trainable attention processors.
823
+
824
+ """
825
+ count = len(self.attn_processors.keys())
826
+
827
+ if isinstance(processor, dict) and len(processor) != count:
828
+ raise ValueError(
829
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
830
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
831
+ )
832
+
833
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
834
+ if hasattr(module, "set_processor"):
835
+ if not isinstance(processor, dict):
836
+ module.set_processor(processor)
837
+ else:
838
+ module.set_processor(processor.pop(f"{name}.processor"))
839
+
840
+ for sub_name, child in module.named_children():
841
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
842
+
843
+ for name, module in self.named_children():
844
+ fn_recursive_attn_processor(name, module, processor)
845
+
846
+ def forward(
847
+ self,
848
+ hidden_states: torch.Tensor,
849
+ encoder_hidden_states: torch.Tensor = None,
850
+ pooled_projections: torch.Tensor = None,
851
+ timestep: torch.LongTensor = None,
852
+ img_ids: torch.Tensor = None,
853
+ txt_ids: torch.Tensor = None,
854
+ guidance: torch.Tensor = None,
855
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
856
+ controlnet_block_samples=None,
857
+ controlnet_single_block_samples=None,
858
+ return_dict: bool = True,
859
+ attention_mask: Optional[torch.Tensor] = None,
860
+ controlnet_blocks_repeat: bool = False,
861
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
862
+ """
863
+ The [`FluxTransformer2DModel`] forward method.
864
+
865
+ Args:
866
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
867
+ Input `hidden_states`.
868
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
869
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
870
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
871
+ from the embeddings of input conditions.
872
+ timestep ( `torch.LongTensor`):
873
+ Used to indicate denoising step.
874
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
875
+ A list of tensors that if specified are added to the residuals of transformer blocks.
876
+ joint_attention_kwargs (`dict`, *optional*):
877
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
878
+ `self.processor` in
879
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
880
+ return_dict (`bool`, *optional*, defaults to `True`):
881
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
882
+ tuple.
883
+
884
+ Returns:
885
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
886
+ `tuple` where the first element is the sample tensor.
887
+ """
888
+ if joint_attention_kwargs is not None:
889
+ joint_attention_kwargs = joint_attention_kwargs.copy()
890
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
891
+ else:
892
+ lora_scale = 1.0
893
+
894
+ if USE_PEFT_BACKEND:
895
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
896
+ scale_lora_layers(self, lora_scale)
897
+ else:
898
+ if (
899
+ joint_attention_kwargs is not None
900
+ and joint_attention_kwargs.get("scale", None) is not None
901
+ ):
902
+ logger.warning(
903
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
904
+ )
905
+ hidden_states = self.x_embedder(hidden_states)
906
+
907
+ timestep = timestep.to(hidden_states.dtype) * 1000
908
+ if guidance is not None:
909
+ guidance = guidance.to(hidden_states.dtype) * 1000
910
+ else:
911
+ guidance = None
912
+
913
+ #print( self.time_text_embed)
914
+ temb = (
915
+ self.time_text_embed(timestep, pooled_projections)
916
+ # Edit 1 # Charlie NOT NEEDED - UNDONE
917
+ if guidance is None
918
+ else self.time_text_embed(timestep, guidance, pooled_projections)
919
+ )
920
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
921
+
922
+ if txt_ids.ndim == 3:
923
+ txt_ids = txt_ids[0]
924
+ if img_ids.ndim == 3:
925
+ img_ids = img_ids[0]
926
+
927
+ ids = torch.cat((txt_ids, img_ids), dim=0)
928
+
929
+ image_rotary_emb = self.pos_embed(ids)
930
+
931
+ # IP adapter
932
+ if (
933
+ joint_attention_kwargs is not None
934
+ and "ip_adapter_image_embeds" in joint_attention_kwargs
935
+ ):
936
+ ip_adapter_image_embeds = joint_attention_kwargs.pop(
937
+ "ip_adapter_image_embeds"
938
+ )
939
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
940
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
941
+
942
+ for index_block, block in enumerate(self.transformer_blocks):
943
+ if (
944
+ self.training
945
+ and self.gradient_checkpointing
946
+ and (
947
+ self.gradient_checkpointing_interval is None
948
+ or index_block % self.gradient_checkpointing_interval == 0
949
+ )
950
+ ):
951
+
952
+ def create_custom_forward(module, return_dict=None):
953
+ def custom_forward(*inputs):
954
+ if return_dict is not None:
955
+ return module(*inputs, return_dict=return_dict)
956
+ else:
957
+ return module(*inputs)
958
+
959
+ return custom_forward
960
+
961
+ ckpt_kwargs: Dict[str, Any] = (
962
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
963
+ )
964
+ encoder_hidden_states, hidden_states = (
965
+ torch.utils.checkpoint.checkpoint(
966
+ create_custom_forward(block),
967
+ hidden_states,
968
+ encoder_hidden_states,
969
+ temb,
970
+ image_rotary_emb,
971
+ attention_mask,
972
+ **ckpt_kwargs,
973
+ )
974
+ )
975
+
976
+ else:
977
+ encoder_hidden_states, hidden_states = block(
978
+ hidden_states=hidden_states,
979
+ encoder_hidden_states=encoder_hidden_states,
980
+ temb=temb,
981
+ image_rotary_emb=image_rotary_emb,
982
+ attention_mask=attention_mask,
983
+ )
984
+
985
+ # controlnet residual
986
+ if controlnet_block_samples is not None:
987
+ interval_control = len(self.transformer_blocks) / len(
988
+ controlnet_block_samples
989
+ )
990
+ interval_control = int(np.ceil(interval_control))
991
+ # For Xlabs ControlNet.
992
+ if controlnet_blocks_repeat:
993
+ hidden_states = (
994
+ hidden_states
995
+ + controlnet_block_samples[
996
+ index_block % len(controlnet_block_samples)
997
+ ]
998
+ )
999
+ else:
1000
+ hidden_states = (
1001
+ hidden_states
1002
+ + controlnet_block_samples[index_block // interval_control]
1003
+ )
1004
+
1005
+ # Flux places the text tokens in front of the image tokens in the
1006
+ # sequence.
1007
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1008
+
1009
+ for index_block, block in enumerate(self.single_transformer_blocks):
1010
+ if (
1011
+ self.training
1012
+ and self.gradient_checkpointing
1013
+ or (
1014
+ self.gradient_checkpointing_interval is not None
1015
+ and index_block % self.gradient_checkpointing_interval == 0
1016
+ )
1017
+ ):
1018
+
1019
+ def create_custom_forward(module, return_dict=None):
1020
+ def custom_forward(*inputs):
1021
+ if return_dict is not None:
1022
+ return module(*inputs, return_dict=return_dict)
1023
+ else:
1024
+ return module(*inputs)
1025
+
1026
+ return custom_forward
1027
+
1028
+ ckpt_kwargs: Dict[str, Any] = (
1029
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1030
+ )
1031
+ hidden_states = torch.utils.checkpoint.checkpoint(
1032
+ create_custom_forward(block),
1033
+ hidden_states,
1034
+ temb,
1035
+ image_rotary_emb,
1036
+ attention_mask,
1037
+ **ckpt_kwargs,
1038
+ )
1039
+
1040
+ else:
1041
+ hidden_states = block(
1042
+ hidden_states=hidden_states,
1043
+ temb=temb,
1044
+ image_rotary_emb=image_rotary_emb,
1045
+ attention_mask=attention_mask,
1046
+ )
1047
+
1048
+ # controlnet residual
1049
+ if controlnet_single_block_samples is not None:
1050
+ interval_control = len(self.single_transformer_blocks) / len(
1051
+ controlnet_single_block_samples
1052
+ )
1053
+ interval_control = int(np.ceil(interval_control))
1054
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1055
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1056
+ + controlnet_single_block_samples[index_block // interval_control]
1057
+ )
1058
+
1059
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1060
+
1061
+ hidden_states = self.norm_out(hidden_states, temb)
1062
+ output = self.proj_out(hidden_states)
1063
+
1064
+ if USE_PEFT_BACKEND:
1065
+ # remove `lora_scale` from each PEFT layer
1066
+ unscale_lora_layers(self, lora_scale)
1067
+
1068
+ if not return_dict:
1069
+ return (output,)
1070
+
1071
+ return Transformer2DModelOutput(sample=output)
1072
+
1073
+ ####################################
1074
+ ##### CONTROL NET MODEL MERGE ######
1075
+ ####################################
1076
+
1077
+
1078
+ from dataclasses import dataclass
1079
+ from typing import Any, Dict, List, Optional, Tuple, Union
1080
+
1081
+ import torch
1082
+ import torch.nn as nn
1083
+
1084
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
1085
+ from diffusers.loaders import PeftAdapterMixin
1086
+ from diffusers.models.attention_processor import AttentionProcessor
1087
+ from diffusers.models.modeling_utils import ModelMixin
1088
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
1089
+ from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
1090
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
1091
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
1092
+
1093
+
1094
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1095
+
1096
+
1097
+ @dataclass
1098
+ class FluxControlNetOutput(BaseOutput):
1099
+ controlnet_block_samples: Tuple[torch.Tensor]
1100
+ controlnet_single_block_samples: Tuple[torch.Tensor]
1101
+
1102
+
1103
+ class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
1104
+ _supports_gradient_checkpointing = True
1105
+
1106
+ @register_to_config
1107
+ def __init__(
1108
+ self,
1109
+ patch_size: int = 1,
1110
+ in_channels: int = 64,
1111
+ num_layers: int = 19,
1112
+ num_single_layers: int = 38,
1113
+ attention_head_dim: int = 128,
1114
+ num_attention_heads: int = 24,
1115
+ joint_attention_dim: int = 4096,
1116
+ pooled_projection_dim: int = 768,
1117
+ guidance_embeds: bool = False,
1118
+ axes_dims_rope: List[int] = [16, 56, 56],
1119
+ num_mode: int = None,
1120
+ conditioning_embedding_channels: int = None,
1121
+ ):
1122
+ super().__init__()
1123
+ self.out_channels = in_channels
1124
+ self.inner_dim = num_attention_heads * attention_head_dim
1125
+
1126
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
1127
+
1128
+ # edit 19
1129
+ #text_time_guidance_cls = (
1130
+ # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
1131
+ #)
1132
+
1133
+ text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings
1134
+ text_time_cls = CombinedTimestepTextProjEmbeddings
1135
+
1136
+ self.time_text_embed = text_time_cls(
1137
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
1138
+ )
1139
+ self.time_text_guidance_embed = text_time_guidance_cls(
1140
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
1141
+ )
1142
+
1143
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
1144
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
1145
+
1146
+ self.transformer_blocks = nn.ModuleList(
1147
+ [
1148
+ FluxTransformerBlock(
1149
+ dim=self.inner_dim,
1150
+ num_attention_heads=num_attention_heads,
1151
+ attention_head_dim=attention_head_dim,
1152
+ )
1153
+ for i in range(num_layers)
1154
+ ]
1155
+ )
1156
+
1157
+ self.single_transformer_blocks = nn.ModuleList(
1158
+ [
1159
+ FluxSingleTransformerBlock(
1160
+ dim=self.inner_dim,
1161
+ num_attention_heads=num_attention_heads,
1162
+ attention_head_dim=attention_head_dim,
1163
+ )
1164
+ for i in range(num_single_layers)
1165
+ ]
1166
+ )
1167
+
1168
+ # controlnet_blocks
1169
+ self.controlnet_blocks = nn.ModuleList([])
1170
+ for _ in range(len(self.transformer_blocks)):
1171
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
1172
+
1173
+ self.controlnet_single_blocks = nn.ModuleList([])
1174
+ for _ in range(len(self.single_transformer_blocks)):
1175
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
1176
+
1177
+ self.union = num_mode is not None
1178
+ if self.union:
1179
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
1180
+
1181
+ if conditioning_embedding_channels is not None:
1182
+ self.input_hint_block = ControlNetConditioningEmbedding(
1183
+ conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
1184
+ )
1185
+ self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
1186
+ else:
1187
+ self.input_hint_block = None
1188
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
1189
+
1190
+ self.gradient_checkpointing = False
1191
+
1192
+ @property
1193
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1194
+ def attn_processors(self):
1195
+ r"""
1196
+ Returns:
1197
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
1198
+ indexed by its weight name.
1199
+ """
1200
+ # set recursively
1201
+ processors = {}
1202
+
1203
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1204
+ if hasattr(module, "get_processor"):
1205
+ processors[f"{name}.processor"] = module.get_processor()
1206
+
1207
+ for sub_name, child in module.named_children():
1208
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1209
+
1210
+ return processors
1211
+
1212
+ for name, module in self.named_children():
1213
+ fn_recursive_add_processors(name, module, processors)
1214
+
1215
+ return processors
1216
+
1217
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1218
+ def set_attn_processor(self, processor):
1219
+ r"""
1220
+ Sets the attention processor to use to compute attention.
1221
+
1222
+ Parameters:
1223
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1224
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1225
+ for **all** `Attention` layers.
1226
+
1227
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1228
+ processor. This is strongly recommended when setting trainable attention processors.
1229
+
1230
+ """
1231
+ count = len(self.attn_processors.keys())
1232
+
1233
+ if isinstance(processor, dict) and len(processor) != count:
1234
+ raise ValueError(
1235
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1236
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1237
+ )
1238
+
1239
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1240
+ if hasattr(module, "set_processor"):
1241
+ if not isinstance(processor, dict):
1242
+ module.set_processor(processor)
1243
+ else:
1244
+ module.set_processor(processor.pop(f"{name}.processor"))
1245
+
1246
+ for sub_name, child in module.named_children():
1247
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1248
+
1249
+ for name, module in self.named_children():
1250
+ fn_recursive_attn_processor(name, module, processor)
1251
+
1252
+ def _set_gradient_checkpointing(self, module, value=False):
1253
+ if hasattr(module, "gradient_checkpointing"):
1254
+ module.gradient_checkpointing = value
1255
+
1256
+ @classmethod
1257
+ def from_transformer(
1258
+ cls,
1259
+ transformer,
1260
+ num_layers: int = 4,
1261
+ num_single_layers: int = 10,
1262
+ attention_head_dim: int = 128,
1263
+ num_attention_heads: int = 24,
1264
+ load_weights_from_transformer=True,
1265
+ ):
1266
+ config = dict(transformer.config)
1267
+ config["num_layers"] = num_layers
1268
+ config["num_single_layers"] = num_single_layers
1269
+ config["attention_head_dim"] = attention_head_dim
1270
+ config["num_attention_heads"] = num_attention_heads
1271
+
1272
+ controlnet = cls.from_config(config)
1273
+
1274
+ if load_weights_from_transformer:
1275
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
1276
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
1277
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
1278
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
1279
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
1280
+ controlnet.single_transformer_blocks.load_state_dict(
1281
+ transformer.single_transformer_blocks.state_dict(), strict=False
1282
+ )
1283
+
1284
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
1285
+
1286
+ return controlnet
1287
+
1288
+ # Edit 13 Adding attention masking to forward
1289
+ def forward(
1290
+ self,
1291
+ hidden_states: torch.Tensor,
1292
+ controlnet_cond: torch.Tensor,
1293
+ controlnet_mode: torch.Tensor = None,
1294
+ conditioning_scale: float = 1.0,
1295
+ encoder_hidden_states: torch.Tensor = None,
1296
+ pooled_projections: torch.Tensor = None,
1297
+ timestep: torch.LongTensor = None,
1298
+ img_ids: torch.Tensor = None,
1299
+ txt_ids: torch.Tensor = None,
1300
+ guidance: torch.Tensor = None,
1301
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1302
+ return_dict: bool = True,
1303
+ attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE
1304
+
1305
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
1306
+ """
1307
+ The [`FluxTransformer2DModel`] forward method.
1308
+
1309
+ Args:
1310
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
1311
+ Input `hidden_states`.
1312
+ controlnet_cond (`torch.Tensor`):
1313
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
1314
+ controlnet_mode (`torch.Tensor`):
1315
+ The mode tensor of shape `(batch_size, 1)`.
1316
+ conditioning_scale (`float`, defaults to `1.0`):
1317
+ The scale factor for ControlNet outputs.
1318
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
1319
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1320
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
1321
+ from the embeddings of input conditions.
1322
+ timestep ( `torch.LongTensor`):
1323
+ Used to indicate denoising step.
1324
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
1325
+ A list of tensors that if specified are added to the residuals of transformer blocks.
1326
+ joint_attention_kwargs (`dict`, *optional*):
1327
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1328
+ `self.processor` in
1329
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1330
+ return_dict (`bool`, *optional*, defaults to `True`):
1331
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1332
+ tuple.
1333
+
1334
+ Returns:
1335
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1336
+ `tuple` where the first element is the sample tensor.
1337
+ """
1338
+ if joint_attention_kwargs is not None:
1339
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1340
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1341
+ else:
1342
+ lora_scale = 1.0
1343
+
1344
+ if USE_PEFT_BACKEND:
1345
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1346
+ scale_lora_layers(self, lora_scale)
1347
+ else:
1348
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
1349
+ logger.warning(
1350
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1351
+ )
1352
+ hidden_states = self.x_embedder(hidden_states)
1353
+
1354
+ if self.input_hint_block is not None:
1355
+ controlnet_cond = self.input_hint_block(controlnet_cond)
1356
+ batch_size, channels, height_pw, width_pw = controlnet_cond.shape
1357
+ height = height_pw // self.config.patch_size
1358
+ width = width_pw // self.config.patch_size
1359
+ controlnet_cond = controlnet_cond.reshape(
1360
+ batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
1361
+ )
1362
+ controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
1363
+ controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
1364
+ # add
1365
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
1366
+
1367
+ timestep = timestep.to(hidden_states.dtype) * 1000
1368
+ if guidance is not None:
1369
+ guidance = guidance.to(hidden_states.dtype) * 1000
1370
+ else:
1371
+ guidance = None
1372
+
1373
+ #print ('Guidance:', guidance)
1374
+ temb = (
1375
+ self.time_text_embed(timestep, pooled_projections)
1376
+ if guidance is None
1377
+ # edit 19
1378
+ else self.time_text_guidance_embed(timestep, guidance, pooled_projections)
1379
+ )
1380
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1381
+
1382
+ if self.union:
1383
+ # union mode
1384
+ if controlnet_mode is None:
1385
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
1386
+ # union mode emb
1387
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
1388
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
1389
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
1390
+
1391
+ if txt_ids.ndim == 3:
1392
+ logger.warning(
1393
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1394
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1395
+ )
1396
+ txt_ids = txt_ids[0]
1397
+ if img_ids.ndim == 3:
1398
+ logger.warning(
1399
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1400
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1401
+ )
1402
+ img_ids = img_ids[0]
1403
+
1404
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1405
+ image_rotary_emb = self.pos_embed(ids)
1406
+
1407
+ block_samples = ()
1408
+ for index_block, block in enumerate(self.transformer_blocks):
1409
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1410
+
1411
+ def create_custom_forward(module, return_dict=None):
1412
+ def custom_forward(*inputs):
1413
+ if return_dict is not None:
1414
+ return module(*inputs, return_dict=return_dict)
1415
+ else:
1416
+ return module(*inputs)
1417
+
1418
+ return custom_forward
1419
+
1420
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1421
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1422
+ create_custom_forward(block),
1423
+ hidden_states,
1424
+ encoder_hidden_states,
1425
+ temb,
1426
+ image_rotary_emb,
1427
+ attention_mask, # Edit 13
1428
+ **ckpt_kwargs,
1429
+ )
1430
+
1431
+ else:
1432
+ encoder_hidden_states, hidden_states = block(
1433
+ hidden_states=hidden_states,
1434
+ encoder_hidden_states=encoder_hidden_states,
1435
+ temb=temb,
1436
+ image_rotary_emb=image_rotary_emb,
1437
+ attention_mask=attention_mask, # Edit 13
1438
+
1439
+ )
1440
+ block_samples = block_samples + (hidden_states,)
1441
+
1442
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1443
+
1444
+ single_block_samples = ()
1445
+ for index_block, block in enumerate(self.single_transformer_blocks):
1446
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1447
+
1448
+ def create_custom_forward(module, return_dict=None):
1449
+ def custom_forward(*inputs):
1450
+ if return_dict is not None:
1451
+ return module(*inputs, return_dict=return_dict)
1452
+ else:
1453
+ return module(*inputs)
1454
+
1455
+ return custom_forward
1456
+
1457
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1458
+ hidden_states = torch.utils.checkpoint.checkpoint(
1459
+ create_custom_forward(block),
1460
+ hidden_states,
1461
+ temb,
1462
+ image_rotary_emb,
1463
+ attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13
1464
+ **ckpt_kwargs,
1465
+ )
1466
+
1467
+ else:
1468
+ hidden_states = block(
1469
+ hidden_states=hidden_states,
1470
+ temb=temb,
1471
+ image_rotary_emb=image_rotary_emb,
1472
+ attention_mask=attention_mask, # <-- 2. PASS MASK TO BLOCK Edit 13
1473
+
1474
+ )
1475
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
1476
+
1477
+ # controlnet block
1478
+ controlnet_block_samples = ()
1479
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
1480
+ block_sample = controlnet_block(block_sample)
1481
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
1482
+
1483
+ controlnet_single_block_samples = ()
1484
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
1485
+ single_block_sample = controlnet_block(single_block_sample)
1486
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
1487
+
1488
+ # scaling
1489
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
1490
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
1491
+
1492
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
1493
+ controlnet_single_block_samples = (
1494
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
1495
+ )
1496
+
1497
+ if USE_PEFT_BACKEND:
1498
+ # remove `lora_scale` from each PEFT layer
1499
+ unscale_lora_layers(self, lora_scale)
1500
+
1501
+ if not return_dict:
1502
+ return (controlnet_block_samples, controlnet_single_block_samples)
1503
+
1504
+ return FluxControlNetOutput(
1505
+ controlnet_block_samples=controlnet_block_samples,
1506
+ controlnet_single_block_samples=controlnet_single_block_samples,
1507
+ )
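The two tuples returned above are consumed by the transformer at every denoising step. A minimal sketch of that hand-off (variable names such as `controlnet`, `transformer`, `latents`, and the prepared id/embedding tensors are assumptions here; the full wiring is in `pipeline.py` below):

```py
# Sketch of one denoising step with pre-prepared (assumed) inputs.
block_samples, single_block_samples = controlnet(
    hidden_states=latents,                 # packed image latents, shape (B, seq, 64)
    controlnet_cond=control_latents,       # packed VAE latents of the control image
    conditioning_scale=0.8,                # residuals are scaled by this inside forward()
    timestep=t / 1000,
    guidance=None,
    pooled_projections=pooled_prompt_embeds,
    encoder_hidden_states=prompt_embeds,
    attention_mask=prompt_attention_mask,  # T5 padding mask threaded through this fork
    txt_ids=text_ids,
    img_ids=latent_image_ids,
    return_dict=False,
)
noise_pred = transformer(
    hidden_states=latents,
    timestep=t / 1000,
    guidance=None,
    pooled_projections=pooled_prompt_embeds,
    encoder_hidden_states=prompt_embeds,
    attention_mask=prompt_attention_mask,
    controlnet_block_samples=block_samples,
    controlnet_single_block_samples=single_block_samples,
    txt_ids=text_ids,
    img_ids=latent_image_ids,
    return_dict=False,
)[0]
```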
model_index.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "_class_name": "LibreFluxControlNetPipeline",
3
+ "_diffusers_version": "0.32.0",
4
+ "controlnet": [
5
+ "net",
6
+ "LibreFluxControlNetModel"
7
+ ],
8
+ "scheduler": [
9
+ "diffusers",
10
+ "FlowMatchEulerDiscreteScheduler"
11
+ ],
12
+ "text_encoder": [
13
+ "transformers",
14
+ "CLIPTextModel"
15
+ ],
16
+ "text_encoder_2": [
17
+ "transformers",
18
+ "T5EncoderModel"
19
+ ],
20
+ "tokenizer": [
21
+ "transformers",
22
+ "CLIPTokenizer"
23
+ ],
24
+ "tokenizer_2": [
25
+ "transformers",
26
+ "T5TokenizerFast"
27
+ ],
28
+ "transformer": [
29
+ "trans",
30
+ "LibreFluxTransformer2DModel"
31
+ ],
32
+ "vae": [
33
+ "diffusers",
34
+ "AutoencoderKL"
35
+ ]
36
+ }
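`model_index.json` maps each component to a `[module, class]` pair; `controlnet` and `transformer` resolve to the local `net.py` and `trans.py` modules rather than to classes shipped with `diffusers`. A hedged sketch of assembling the pipeline by hand from a local clone (the package name `libreflux_controlnet`, the clone path, the presence of the remaining subfolders listed above, and the custom model classes exposing the standard `from_pretrained` are all assumptions):

```py
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

# Hypothetical package name for a local clone of this repo; __init__.py re-exports these classes.
from libreflux_controlnet import (
    LibreFluxControlNetPipeline,
    LibreFluxTransformer2DModel,
    LibreFluxControlNetModel,
)

root = "./libreflux_controlnet"  # assumed local path to the cloned repository
dtype = torch.bfloat16

pipe = LibreFluxControlNetPipeline(
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(root, subfolder="scheduler"),
    vae=AutoencoderKL.from_pretrained(root, subfolder="vae", torch_dtype=dtype),
    text_encoder=CLIPTextModel.from_pretrained(root, subfolder="text_encoder", torch_dtype=dtype),
    tokenizer=CLIPTokenizer.from_pretrained(root, subfolder="tokenizer"),
    text_encoder_2=T5EncoderModel.from_pretrained(root, subfolder="text_encoder_2", torch_dtype=dtype),
    tokenizer_2=T5TokenizerFast.from_pretrained(root, subfolder="tokenizer_2"),
    transformer=LibreFluxTransformer2DModel.from_pretrained(root, subfolder="transformer", torch_dtype=dtype),
    controlnet=LibreFluxControlNetModel.from_pretrained(root, subfolder="controlnet", torch_dtype=dtype),
)
```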
pipeline.py ADDED
@@ -0,0 +1,973 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file was modified from the Flux ControlNet pipeline code in diffusers
16
+
17
+
18
+ import inspect
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
22
+
23
+ import numpy as np
24
+ import torch
25
+ from transformers import (
26
+ CLIPTextModel,
27
+ CLIPTokenizer,
28
+ T5EncoderModel,
29
+ T5TokenizerFast,
30
+ )
31
+
32
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
33
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
34
+ from diffusers.models.autoencoders import AutoencoderKL
35
+
36
+ from .controlnet.net import LibreFluxControlNetModel
37
+ from .transformer.trans import LibreFluxTransformer2DModel
38
+
39
+ ####################################
40
+ ##### ACTUAL PIPELINE STUFF ########
41
+ ####################################
42
+
43
+
44
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
45
+ from diffusers.utils import (
46
+ USE_PEFT_BACKEND,
47
+ is_torch_xla_available,
48
+ logging,
49
+ replace_example_docstring,
50
+ scale_lora_layers,
51
+ unscale_lora_layers,
52
+ )
53
+ from diffusers.utils.torch_utils import randn_tensor
54
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
55
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
56
+
57
+
58
+ if is_torch_xla_available():
59
+ import torch_xla.core.xla_model as xm
60
+
61
+ XLA_AVAILABLE = True
62
+ else:
63
+ XLA_AVAILABLE = False
64
+
65
+ # TODO(Chris): why won't this emit messages at the INFO level???
66
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
67
+
68
+ EXAMPLE_DOC_STRING = """
69
+ Examples:
70
+ ```py
71
+ >>> import torch
72
+ >>> from diffusers.utils import load_image
73
+ >>> from diffusers import FluxControlNetPipeline
74
+ >>> from diffusers import FluxControlNetModel
75
+
76
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
77
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
78
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
79
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
80
+ ... )
81
+ >>> pipe.to("cuda")
82
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
83
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
84
+ >>> image = pipe(
85
+ ... prompt,
86
+ ... control_image=control_image,
87
+ ... controlnet_conditioning_scale=0.6,
88
+ ... num_inference_steps=28,
89
+ ... guidance_scale=3.5,
90
+ ... ).images[0]
91
+ >>> image.save("flux.png")
92
+ ```
93
+ """
94
+
95
+ def _maybe_to(x: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
96
+ if device is None and dtype is None:
97
+ return x
98
+ need_dev = device is not None and str(getattr(x, "device", None)) != str(device)
99
+ need_dt = dtype is not None and getattr(x, "dtype", None) != dtype
100
+ return x.to(device=device if need_dev else x.device, dtype=dtype if need_dt else x.dtype) if (need_dev or need_dt) else x
101
+
102
+
103
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
104
+ def calculate_shift(
105
+ image_seq_len,
106
+ base_seq_len: int = 256,
107
+ max_seq_len: int = 4096,
108
+ base_shift: float = 0.5,
109
+ max_shift: float = 1.16,
110
+ ):
111
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
112
+ b = base_shift - m * base_seq_len
113
+ mu = image_seq_len * m + b
114
+ return mu
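+ # Worked example (assuming a 1024x1024 generation): the packed latent sequence has
+ # (1024 / 16) ** 2 = 4096 tokens, so with the defaults above
+ # m = (1.16 - 0.5) / (4096 - 256) ≈ 1.72e-4, b = 0.5 - m * 256 ≈ 0.456, and
+ # mu = 4096 * m + b ≈ 1.16, i.e. the shift saturates at `max_shift` for full-resolution images.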
115
+
116
+
117
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
118
+ def retrieve_timesteps(
119
+ scheduler,
120
+ num_inference_steps: Optional[int] = None,
121
+ device: Optional[Union[str, torch.device]] = None,
122
+ timesteps: Optional[List[int]] = None,
123
+ sigmas: Optional[List[float]] = None,
124
+ **kwargs,
125
+ ):
126
+ """
127
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
128
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
129
+
130
+ Args:
131
+ scheduler (`SchedulerMixin`):
132
+ The scheduler to get timesteps from.
133
+ num_inference_steps (`int`):
134
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
135
+ must be `None`.
136
+ device (`str` or `torch.device`, *optional*):
137
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
138
+ timesteps (`List[int]`, *optional*):
139
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
140
+ `num_inference_steps` and `sigmas` must be `None`.
141
+ sigmas (`List[float]`, *optional*):
142
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
143
+ `num_inference_steps` and `timesteps` must be `None`.
144
+
145
+ Returns:
146
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
147
+ second element is the number of inference steps.
148
+ """
149
+ if timesteps is not None and sigmas is not None:
150
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
151
+ if timesteps is not None:
152
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
153
+ if not accepts_timesteps:
154
+ raise ValueError(
155
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
156
+ f" timestep schedules. Please check whether you are using the correct scheduler."
157
+ )
158
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
159
+ timesteps = scheduler.timesteps
160
+ num_inference_steps = len(timesteps)
161
+ elif sigmas is not None:
162
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
163
+ if not accept_sigmas:
164
+ raise ValueError(
165
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
166
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
167
+ )
168
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
169
+ timesteps = scheduler.timesteps
170
+ num_inference_steps = len(timesteps)
171
+ else:
172
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
173
+ timesteps = scheduler.timesteps
174
+ return timesteps, num_inference_steps
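+ # Usage sketch (names assumed): the pipeline below calls this with custom sigmas plus the
+ # flow-matching shift, e.g. `timesteps, n = retrieve_timesteps(scheduler, device=device, sigmas=sigmas, mu=mu)`.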
175
+
176
+
177
+ class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
178
+ r"""
179
+ The Flux pipeline for text-to-image generation.
180
+
181
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
182
+
183
+ Args:
184
+ transformer ([`FluxTransformer2DModel`]):
185
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
186
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
187
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
188
+ vae ([`AutoencoderKL`]):
189
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
190
+ text_encoder ([`CLIPTextModel`]):
191
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
192
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
193
+ text_encoder_2 ([`T5EncoderModel`]):
194
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
195
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
196
+ tokenizer (`CLIPTokenizer`):
197
+ Tokenizer of class
198
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
199
+ tokenizer_2 (`T5TokenizerFast`):
200
+ Second Tokenizer of class
201
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
202
+ """
203
+
204
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
205
+ _optional_components = []
206
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
207
+
208
+ def __init__(
209
+ self,
210
+ scheduler: FlowMatchEulerDiscreteScheduler,
211
+ vae: AutoencoderKL,
212
+ text_encoder: CLIPTextModel,
213
+ tokenizer: CLIPTokenizer,
214
+ text_encoder_2: T5EncoderModel,
215
+ tokenizer_2: T5TokenizerFast,
216
+ transformer: LibreFluxTransformer2DModel,
217
+ controlnet: Union[
218
+ LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel],
219
+ ],
220
+ ):
221
+ super().__init__()
222
+
223
+ self.register_modules(
224
+ vae=vae,
225
+ text_encoder=text_encoder,
226
+ text_encoder_2=text_encoder_2,
227
+ tokenizer=tokenizer,
228
+ tokenizer_2=tokenizer_2,
229
+ transformer=transformer,
230
+ scheduler=scheduler,
231
+ controlnet=controlnet,
232
+ )
233
+ self.vae_scale_factor = (
234
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
235
+ )
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+ self.tokenizer_max_length = (
238
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
239
+ )
240
+ self.default_sample_size = 64
241
+
242
+ def _get_t5_prompt_embeds(
243
+ self,
244
+ prompt: Union[str, List[str]] = None,
245
+ num_images_per_prompt: int = 1,
246
+ max_sequence_length: int = 512,
247
+ device: Optional[torch.device] = None,
248
+ dtype: Optional[torch.dtype] = None,
249
+ ):
250
+ device = device or self._execution_device
251
+ dtype = dtype or self.text_encoder.dtype
252
+
253
+ prompt = [prompt] if isinstance(prompt, str) else prompt
254
+ batch_size = len(prompt)
255
+
256
+ text_inputs = self.tokenizer_2(
257
+ prompt,
258
+ padding="max_length",
259
+ max_length=max_sequence_length,
260
+ truncation=True,
261
+ return_length=False,
262
+ return_overflowing_tokens=False,
263
+ return_tensors="pt",
264
+ )
265
+ text_input_ids = text_inputs.input_ids
266
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
267
+
268
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
269
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
270
+ logger.warning(
271
+ "The following part of your input was truncated because `max_sequence_length` is set to "
272
+ f" {max_sequence_length} tokens: {removed_text}"
273
+ )
274
+
275
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(self.text_encoder_2.device), output_hidden_states=False)[0]
276
+ #prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
277
+
278
+ dtype = self.text_encoder_2.dtype
279
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
280
+
281
+ _, seq_len, _ = prompt_embeds.shape
282
+
283
+ # duplicate text embeddings for each generation per prompt
284
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
285
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
286
+
287
+ # ADD THIS: Get the attention mask and repeat it for each image
288
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device, dtype=dtype)
289
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
290
+
291
+ # ADD THIS: Return the attention mask
292
+ return prompt_embeds, prompt_attention_mask
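+ # Shape note: `prompt_embeds` is (batch * num_images_per_prompt, max_sequence_length, 4096)
+ # for the T5-XXL encoder, and `prompt_attention_mask` is
+ # (batch * num_images_per_prompt, max_sequence_length) with 1 for real tokens and 0 for padding.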
293
+
294
+ def _get_clip_prompt_embeds(
295
+ self,
296
+ prompt: Union[str, List[str]],
297
+ num_images_per_prompt: int = 1,
298
+ device: Optional[torch.device] = None,
299
+ ):
300
+ device = device or self._execution_device
301
+
302
+ prompt = [prompt] if isinstance(prompt, str) else prompt
303
+ batch_size = len(prompt)
304
+
305
+ text_inputs = self.tokenizer(
306
+ prompt,
307
+ padding="max_length",
308
+ max_length=self.tokenizer_max_length,
309
+ truncation=True,
310
+ return_overflowing_tokens=False,
311
+ return_length=False,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ text_input_ids = text_inputs.input_ids
316
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
318
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
319
+ logger.warning(
320
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
321
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
322
+ )
323
+ prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False)
324
+ #prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
325
+
326
+ # Use pooled output of CLIPTextModel
327
+ prompt_embeds = prompt_embeds.pooler_output
328
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
329
+
330
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
331
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
332
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
333
+
334
+ return prompt_embeds
335
+
336
+ def encode_prompt(
337
+ self,
338
+ prompt: Union[str, List[str]],
339
+ prompt_2: Union[str, List[str]],
340
+ device: Optional[torch.device] = None,
341
+ num_images_per_prompt: int = 1,
342
+ prompt_embeds: Optional[torch.FloatTensor] = None,
343
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
344
+ max_sequence_length: int = 512,
345
+ lora_scale: Optional[float] = None,
346
+ ):
347
+ device = device or self._execution_device
348
+
349
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
350
+ self._lora_scale = lora_scale
351
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
352
+ scale_lora_layers(self.text_encoder, lora_scale)
353
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
354
+ scale_lora_layers(self.text_encoder_2, lora_scale)
355
+
356
+ prompt = [prompt] if isinstance(prompt, str) else prompt
357
+
358
+ if prompt_embeds is None:
359
+ prompt_2 = prompt_2 or prompt
360
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
361
+
362
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
363
+ prompt=prompt,
364
+ device=device,
365
+ num_images_per_prompt=num_images_per_prompt,
366
+ )
367
+
368
+ # ADD THIS: Initialize mask and capture it from the T5 embedder
369
+ prompt_attention_mask = None
370
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
371
+ prompt=prompt_2,
372
+ num_images_per_prompt=num_images_per_prompt,
373
+ max_sequence_length=max_sequence_length,
374
+ device=device,
375
+ )
376
+
377
+ if self.text_encoder is not None:
378
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
379
+ unscale_lora_layers(self.text_encoder, lora_scale)
380
+ if self.text_encoder_2 is not None:
381
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
382
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
383
+
384
+ # FIX: Get batch_size and create text_ids with the correct shape
385
+ batch_size = prompt_embeds.shape[0]
386
+ dtype = self.transformer.dtype
387
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
388
+
389
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
390
+
391
+ def check_inputs(
392
+ self,
393
+ prompt,
394
+ prompt_2,
395
+ height,
396
+ width,
397
+ prompt_embeds=None,
398
+ pooled_prompt_embeds=None,
399
+ callback_on_step_end_tensor_inputs=None,
400
+ max_sequence_length=None,
401
+ ):
402
+ if height % 8 != 0 or width % 8 != 0:
403
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
404
+
405
+ if callback_on_step_end_tensor_inputs is not None and not all(
406
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
407
+ ):
408
+ raise ValueError(
409
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
410
+ )
411
+
412
+ if prompt is not None and prompt_embeds is not None:
413
+ raise ValueError(
414
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
415
+ " only forward one of the two."
416
+ )
417
+ elif prompt_2 is not None and prompt_embeds is not None:
418
+ raise ValueError(
419
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
420
+ " only forward one of the two."
421
+ )
422
+ elif prompt is None and prompt_embeds is None:
423
+ raise ValueError(
424
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
425
+ )
426
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
427
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
428
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
429
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
430
+
431
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
432
+ raise ValueError(
433
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
434
+ )
435
+
436
+ if max_sequence_length is not None and max_sequence_length > 512:
437
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
438
+
439
+ @staticmethod
440
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
441
+ # FIX: Correctly creates batched image IDs
442
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
443
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
444
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
445
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
446
+
447
+ latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1)
448
+
449
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape[1:]
450
+
451
+ latent_image_ids = latent_image_ids.reshape(
452
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
453
+ )
454
+
455
+ return latent_image_ids.to(device=device, dtype=dtype)
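+ # Shape sketch: for a 128x128 packed latent grid this returns (batch_size, 64 * 64, 3)
+ # position ids, where channels 1 and 2 carry the patch row/column indices consumed by RoPE.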
456
+
457
+ @staticmethod
458
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
459
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
460
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
461
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
462
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
463
+
464
+ return latents
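+ # Shape sketch (1024x1024 example): the VAE latent (B, 16, 128, 128) is packed into
+ # (B, 64 * 64, 64) = (B, 4096, 64), i.e. each 2x2 latent patch becomes one 64-dim token.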
465
+
466
+ @staticmethod
467
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
468
+ def _unpack_latents(latents, height, width, vae_scale_factor):
469
+ batch_size, num_patches, channels = latents.shape
470
+
471
+ height = height // vae_scale_factor
472
+ width = width // vae_scale_factor
473
+
474
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
475
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
476
+
477
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
478
+
479
+ return latents
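+ # Inverse of `_pack_latents` (same 1024x1024 example): (B, 4096, 64) tokens are unfolded
+ # back into a (B, 16, 128, 128) latent before VAE decoding.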
480
+
481
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
482
+ def prepare_latents(
483
+ self,
484
+ batch_size,
485
+ num_channels_latents,
486
+ height,
487
+ width,
488
+ dtype,
489
+ device,
490
+ generator,
491
+ latents=None,
492
+ ):
493
+ height = 2 * (int(height) // self.vae_scale_factor)
494
+ width = 2 * (int(width) // self.vae_scale_factor)
495
+
496
+ shape = (batch_size, num_channels_latents, height, width)
497
+
498
+ if latents is not None:
499
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
500
+ return latents.to(device=device, dtype=dtype), latent_image_ids
501
+
502
+ if isinstance(generator, list) and len(generator) != batch_size:
503
+ raise ValueError(
504
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
505
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
506
+ )
507
+
508
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
509
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
510
+
511
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
512
+
513
+ return latents, latent_image_ids
514
+
515
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
516
+ def prepare_image(
517
+ self,
518
+ image,
519
+ width,
520
+ height,
521
+ batch_size,
522
+ num_images_per_prompt,
523
+ device,
524
+ dtype,
525
+ do_classifier_free_guidance=False,
526
+ guess_mode=False,
527
+ ):
528
+ if isinstance(image, torch.Tensor):
529
+ pass
530
+ else:
531
+ image = self.image_processor.preprocess(image, height=height, width=width)
532
+
533
+ image_batch_size = image.shape[0]
534
+
535
+ if image_batch_size == 1:
536
+ repeat_by = batch_size
537
+ else:
538
+ # image batch size is the same as prompt batch size
539
+ repeat_by = num_images_per_prompt
540
+
541
+ image = image.repeat_interleave(repeat_by, dim=0)
542
+
543
+ image = image.to(device=device, dtype=dtype)
544
+
545
+ if do_classifier_free_guidance and not guess_mode:
546
+ image = torch.cat([image] * 2)
547
+
548
+ return image
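+ # Example (assumed inputs): a single PIL control image is preprocessed, resized to
+ # (height, width), normalized to [-1, 1], and repeated along the batch dimension
+ # (the pipeline passes batch_size * num_images_per_prompt here), giving (B, 3, height, width).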
549
+
550
+ @property
551
+ def guidance_scale(self):
552
+ return self._guidance_scale
553
+
554
+ @property
555
+ def joint_attention_kwargs(self):
556
+ return self._joint_attention_kwargs
557
+
558
+ @property
559
+ def num_timesteps(self):
560
+ return self._num_timesteps
561
+
562
+ @property
563
+ def interrupt(self):
564
+ return self._interrupt
565
+
566
+ @torch.no_grad()
567
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
568
+ def __call__(
569
+ self,
570
+ prompt: Union[str, List[str]] = None,
571
+ prompt_2: Optional[Union[str, List[str]]] = None,
572
+ height: Optional[int] = None,
573
+ width: Optional[int] = None,
574
+ num_inference_steps: int = 28,
575
+ timesteps: List[int] = None,
576
+ guidance_scale: float = 7.0,
577
+ control_image: PipelineImageInput = None,
578
+ control_mode: Optional[Union[int, List[int]]] = None,
579
+ control_image_undo_centering: bool = False,
580
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
581
+ num_images_per_prompt: Optional[int] = 1,
582
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
583
+ latents: Optional[torch.FloatTensor] = None,
584
+ prompt_embeds: Optional[torch.FloatTensor] = None,
585
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
586
+ output_type: Optional[str] = "pil",
587
+ return_dict: bool = True,
588
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
589
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
590
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
591
+ max_sequence_length: int = 512,
592
+ negative_prompt: Optional[Union[str, List[str]]] = "",
593
+ negative_prompt_2: Optional[Union[str, List[str]]] = "",
594
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
595
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
596
+ ):
597
+ r"""
598
+ Function invoked when calling the pipeline for generation.
599
+
600
+ Args:
601
+ prompt (`str` or `List[str]`, *optional*):
602
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
603
+ instead.
604
+ prompt_2 (`str` or `List[str]`, *optional*):
605
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
606
+ will be used instead.
607
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
608
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
609
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
610
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
611
+ num_inference_steps (`int`, *optional*, defaults to 28):
612
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
613
+ expense of slower inference.
614
+ timesteps (`List[int]`, *optional*):
615
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
616
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
617
+ passed will be used. Must be in descending order.
618
+ guidance_scale (`float`, *optional*, defaults to 7.0):
619
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
620
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
621
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
622
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
623
+ usually at the expense of lower image quality.
624
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
625
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
626
+ The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
627
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
628
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
629
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
630
+ images must be passed as a list such that each element of the list can be correctly batched for input
631
+ to a single ControlNet.
632
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
633
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
634
+ to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
635
+ the corresponding scale as a list.
636
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
637
+ The control mode when applying ControlNet-Union.
638
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
639
+ The number of images to generate per prompt.
640
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
641
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
642
+ to make generation deterministic.
643
+ latents (`torch.FloatTensor`, *optional*):
644
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
645
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
646
+ tensor will be generated by sampling using the supplied random `generator`.
647
+ prompt_embeds (`torch.FloatTensor`, *optional*):
648
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
649
+ provided, text embeddings will be generated from `prompt` input argument.
650
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
651
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
652
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
653
+ output_type (`str`, *optional*, defaults to `"pil"`):
654
+ The output format of the generated image. Choose between
655
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
656
+ return_dict (`bool`, *optional*, defaults to `True`):
657
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
658
+ joint_attention_kwargs (`dict`, *optional*):
659
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
660
+ `self.processor` in
661
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
662
+ callback_on_step_end (`Callable`, *optional*):
663
+ A function that is called at the end of each denoising step during inference. The function is called
664
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
665
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
666
+ `callback_on_step_end_tensor_inputs`.
667
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
668
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
669
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
670
+ `._callback_tensor_inputs` attribute of your pipeline class.
671
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
672
+
673
+ Examples:
674
+
675
+ Returns:
676
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
677
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
678
+ images.
679
+ """
680
+
681
+ height = height or self.default_sample_size * self.vae_scale_factor
682
+ width = width or self.default_sample_size * self.vae_scale_factor
683
+
684
+ # 1. Check inputs. Raise error if not correct
685
+ self.check_inputs(
686
+ prompt,
687
+ prompt_2,
688
+ height,
689
+ width,
690
+ prompt_embeds=prompt_embeds,
691
+ pooled_prompt_embeds=pooled_prompt_embeds,
692
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
693
+ max_sequence_length=max_sequence_length,
694
+ )
695
+
696
+ self._guidance_scale = guidance_scale
697
+ self._joint_attention_kwargs = joint_attention_kwargs
698
+ self._interrupt = False
699
+
700
+ # 2. Define call parameters
701
+ if prompt is not None and isinstance(prompt, str):
702
+ batch_size = 1
703
+ elif prompt is not None and isinstance(prompt, list):
704
+ batch_size = len(prompt)
705
+ else:
706
+ batch_size = prompt_embeds.shape[0]
707
+
708
+ device = self._execution_device
709
+ dtype = self.transformer.dtype
710
+
711
+ lora_scale = (
712
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
713
+ )
714
+ # 💡 ADD THIS: Capture the attention_mask from encode_prompt
715
+ (
716
+ prompt_embeds,
717
+ pooled_prompt_embeds,
718
+ text_ids,
719
+ attention_mask,
720
+ ) = self.encode_prompt(
721
+ prompt=prompt,
722
+ prompt_2=prompt_2,
723
+ prompt_embeds=prompt_embeds,
724
+ pooled_prompt_embeds=pooled_prompt_embeds,
725
+ device=device,
726
+ num_images_per_prompt=num_images_per_prompt,
727
+ max_sequence_length=max_sequence_length,
728
+ lora_scale=lora_scale,
729
+ )
730
+
731
+ # ✨ FIX: Encode negative prompts for CFG
732
+ do_classifier_free_guidance = guidance_scale > 1.0
733
+ if do_classifier_free_guidance:
734
+ if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
735
+ negative_prompt = negative_prompt or ""
736
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
737
+ (negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids, negative_attention_mask) = self.encode_prompt(
738
+ prompt=negative_prompt, prompt_2=negative_prompt_2, device=device,
739
+ num_images_per_prompt=num_images_per_prompt,
740
+ max_sequence_length=max_sequence_length, lora_scale=lora_scale,
741
+ )
742
+
743
+
744
+ # 3. Prepare control image
745
+ num_channels_latents = self.transformer.config.in_channels // 4
746
+
747
+ if isinstance(self.controlnet, FullyShardedDataParallel):
748
+ inner_module = self.controlnet._fsdp_wrapped_module
749
+ else:
750
+ inner_module = self.controlnet
751
+
752
+ if isinstance(inner_module, LibreFluxControlNetModel):
753
+ control_image = self.prepare_image(
754
+ image=control_image,
755
+ width=width,
756
+ height=height,
757
+ batch_size=batch_size * num_images_per_prompt,
758
+ num_images_per_prompt=num_images_per_prompt,
759
+ device=device,
760
+ dtype=dtype,
761
+ )
762
+
763
+ if control_image_undo_centering:
764
+ if not self.image_processor.do_normalize:
765
+ raise ValueError(
766
+ "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
767
+ )
768
+ control_image = control_image*0.5 + 0.5
769
+
770
+ height, width = control_image.shape[-2:]
771
+
772
+ #logger.warning(
773
+ # f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}"
774
+ #)
775
+
776
+ # vae encode
777
+ control_image = _maybe_to(control_image, device=self.vae.device)
778
+ control_image = self.vae.encode(control_image).latent_dist.sample()
779
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
780
+ control_image = _maybe_to(control_image, device=device)
781
+ # pack
782
+ height_control_image, width_control_image = control_image.shape[2:]
783
+ control_image = self._pack_latents(
784
+ control_image,
785
+ batch_size * num_images_per_prompt,
786
+ num_channels_latents,
787
+ height_control_image,
788
+ width_control_image,
789
+ )
790
+
791
+ # set control mode
792
+ if control_mode is not None:
793
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
794
+ control_mode = control_mode.reshape([-1, 1])
795
+
796
+
797
+ # set control mode
798
+ control_mode_ = []
799
+ if isinstance(control_mode, list):
800
+ for cmode in control_mode:
801
+ if cmode is None:
802
+ control_mode_.append(-1)
803
+ else:
804
+ control_mode_.append(cmode)
805
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
806
+ control_mode = control_mode.reshape([-1, 1])
807
+
808
+ # 4. Prepare latent variables
809
+ num_channels_latents = self.transformer.config.in_channels // 4
810
+ latents, latent_image_ids = self.prepare_latents(
811
+ batch_size * num_images_per_prompt,
812
+ num_channels_latents,
813
+ height,
814
+ width,
815
+ prompt_embeds.dtype,
816
+ device,
817
+ generator,
818
+ latents,
819
+ )
820
+
821
+ # 5. Prepare timesteps
822
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
823
+ image_seq_len = latents.shape[1]
824
+ mu = calculate_shift(
825
+ image_seq_len,
826
+ self.scheduler.config.base_image_seq_len,
827
+ self.scheduler.config.max_image_seq_len,
828
+ self.scheduler.config.base_shift,
829
+ self.scheduler.config.max_shift,
830
+ )
831
+ timesteps, num_inference_steps = retrieve_timesteps(
832
+ self.scheduler,
833
+ num_inference_steps,
834
+ device,
835
+ timesteps,
836
+ sigmas,
837
+ mu=mu,
838
+ )
839
+
840
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
841
+ self._num_timesteps = len(timesteps)
842
+
843
+ # 6. Denoising loop
844
+ target_device = self.transformer.device
845
+ self.controlnet.to(target_device)
846
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
847
+ for i, t in enumerate(timesteps):
848
+ if self.interrupt:
849
+ continue
850
+
851
+
852
+ # FIX: BATCH INPUTS FOR CFG
853
+ if do_classifier_free_guidance:
854
+ latent_model_input = torch.cat([latents] * 2)
855
+ current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
856
+ current_pooled_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
857
+ current_attention_mask = torch.cat([negative_attention_mask, attention_mask])
858
+ current_text_ids = torch.cat([negative_text_ids, text_ids])
859
+ current_img_ids = torch.cat([latent_image_ids] * 2)
860
+ current_control_image = torch.cat([control_image] * 2) if isinstance(control_image, torch.Tensor) else [torch.cat([c_img] * 2) for c_img in control_image]
861
+ else:
862
+ latent_model_input = latents
863
+ current_prompt_embeds = prompt_embeds
864
+ current_pooled_embeds = pooled_prompt_embeds
865
+ current_attention_mask = attention_mask
866
+ current_text_ids = text_ids
867
+ current_img_ids = latent_image_ids
868
+ current_control_image = control_image
869
+
870
+ # FIX: Integrate with device handling
871
+ target_device = self.transformer.device
872
+
873
+ # Move all inputs to the target device
874
+ latent_model_input = _maybe_to(latent_model_input, device=target_device)
875
+ current_prompt_embeds = _maybe_to(current_prompt_embeds, device=target_device)
876
+ current_pooled_embeds = _maybe_to(current_pooled_embeds, device=target_device)
877
+ current_attention_mask = _maybe_to(current_attention_mask, device=target_device)
878
+ current_text_ids = _maybe_to(current_text_ids, device=target_device)
879
+ current_img_ids = _maybe_to(current_img_ids, device=target_device)
880
+ if isinstance(current_control_image, torch.Tensor):
881
+ current_control_image = _maybe_to(current_control_image, device=target_device)
882
+ else:
883
+ current_control_image = [ _maybe_to(c, device=target_device) for c in current_control_image ]
884
+ control_mode = _maybe_to(control_mode, device=target_device) if control_mode is not None else None
885
+
886
+ t_model = t.expand(latent_model_input.shape[0]).to(target_device)
887
+
888
+
889
+ # Model calls
890
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
891
+ hidden_states=latent_model_input,
892
+ controlnet_cond=current_control_image,
893
+ controlnet_mode=control_mode,
894
+ conditioning_scale=controlnet_conditioning_scale,
895
+ timestep=(t_model / 1000),
896
+ guidance=None,
897
+ pooled_projections=current_pooled_embeds,
898
+ encoder_hidden_states=current_prompt_embeds,
899
+ attention_mask=current_attention_mask,
900
+ txt_ids=current_text_ids,
901
+ img_ids=current_img_ids,
902
+ joint_attention_kwargs=self.joint_attention_kwargs,
903
+ return_dict=False
904
+ )
905
+
906
+ controlnet_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_block_samples]
907
+ controlnet_single_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_single_block_samples]
908
+
909
+ noise_pred = self.transformer(
910
+ hidden_states=latent_model_input,
911
+ timestep=(t_model / 1000),
912
+ guidance=None,
913
+ pooled_projections=current_pooled_embeds,
914
+ encoder_hidden_states=current_prompt_embeds,
915
+ attention_mask=current_attention_mask,
916
+ controlnet_block_samples=controlnet_block_samples,
917
+ controlnet_single_block_samples=controlnet_single_block_samples,
918
+ txt_ids=current_text_ids,
919
+ img_ids=current_img_ids,
920
+ joint_attention_kwargs=self.joint_attention_kwargs,
921
+ return_dict=False
922
+ )[0]
923
+
924
+ # FIX: Apply CFG formula
925
+ if do_classifier_free_guidance:
926
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
927
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
928
+
929
+ ## Probably not needed
930
+ #noise_pred = noise_pred.to(latents.device)
931
+
932
+ latents_dtype = latents.dtype
933
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
934
+
935
+ if latents.dtype != latents_dtype:
936
+ if torch.backends.mps.is_available():
937
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
938
+ latents = latents.to(latents_dtype)
939
+
940
+ if callback_on_step_end is not None:
941
+ callback_kwargs = {}
942
+ for k in callback_on_step_end_tensor_inputs:
943
+ callback_kwargs[k] = locals()[k]
944
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
945
+
946
+ latents = callback_outputs.pop("latents", latents)
947
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
948
+
949
+ # call the callback, if provided
950
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
951
+ progress_bar.update()
952
+
953
+ if XLA_AVAILABLE:
954
+ xm.mark_step()
955
+
956
+ if output_type == "latent":
957
+ image = latents
958
+
959
+ else:
960
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
961
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
962
+
963
+ latents = _maybe_to(latents, device=self.vae.device)
964
+ image = self.vae.decode(latents, return_dict=False)[0]
965
+ image = self.image_processor.postprocess(image, output_type=output_type)
966
+
967
+ # Offload all models
968
+ self.maybe_free_model_hooks()
969
+
970
+ if not return_dict:
971
+ return (image,)
972
+
973
+ return FluxPipelineOutput(images=image)
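Assuming `pipe` is the assembled `LibreFluxControlNetPipeline` from the earlier sketch, a call that exercises the classifier-free-guidance branch this fork adds (`guidance_scale > 1` enables negative-prompt guidance inside `__call__`); the file names and conditioning values are placeholders:

```py
import torch
from diffusers.utils import load_image

pipe.to("cuda")

control_image = load_image("canny.png")  # hypothetical preprocessed control map
image = pipe(
    prompt="a cabin in the woods at dusk, volumetric light",
    negative_prompt="blurry, low quality",
    control_image=control_image,
    controlnet_conditioning_scale=0.8,
    guidance_scale=3.5,            # > 1.0 triggers the CFG path added in __call__
    num_inference_steps=28,
    height=1024,
    width=1024,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("libreflux_controlnet.png")
```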
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.30.0.dev0",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "max_image_seq_len": 4096,
7
+ "max_shift": 1.15,
8
+ "num_train_timesteps": 1000,
9
+ "shift": 1.0,
10
+ "use_dynamic_shifting": false
11
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "bfloat16",
23
+ "transformers_version": "4.43.3",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:893d67a23f4693ed42cdab4cbad7fe3e727cf59609c40da28a46b5470f9ed082
3
+ size 246144352
text_encoder_2/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/t5-v1_1-xxl",
3
+ "architectures": [
4
+ "T5EncoderModel"
5
+ ],
6
+ "classifier_dropout": 0.0,
7
+ "d_ff": 10240,
8
+ "d_kv": 64,
9
+ "d_model": 4096,
10
+ "decoder_start_token_id": 0,
11
+ "dense_act_fn": "gelu_new",
12
+ "dropout_rate": 0.1,
13
+ "eos_token_id": 1,
14
+ "feed_forward_proj": "gated-gelu",
15
+ "initializer_factor": 1.0,
16
+ "is_encoder_decoder": true,
17
+ "is_gated_act": true,
18
+ "layer_norm_epsilon": 1e-06,
19
+ "model_type": "t5",
20
+ "num_decoder_layers": 24,
21
+ "num_heads": 64,
22
+ "num_layers": 24,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "bfloat16",
29
+ "transformers_version": "4.43.3",
30
+ "use_cache": true,
31
+ "vocab_size": 32128
32
+ }
text_encoder_2/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec87bffd1923e8b2774a6d240c922a41f6143081d52cf83b8fe39e9d838c893e
3
+ size 4994582224
text_encoder_2/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5640855b301fcdbceddfa90ae8066cd9414aff020552a201a255ecf2059da00
3
+ size 4530066360
text_encoder_2/model.safetensors.index.json ADDED
@@ -0,0 +1,226 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 9524621312
4
+ },
5
+ "weight_map": {
6
+ "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
7
+ "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
8
+ "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
9
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
10
+ "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
11
+ "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
12
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
13
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
14
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
15
+ "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
16
+ "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
17
+ "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
18
+ "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
19
+ "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
20
+ "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
21
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
22
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
23
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
24
+ "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
25
+ "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
26
+ "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
27
+ "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
28
+ "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
29
+ "encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
30
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
31
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
32
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
33
+ "encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
34
+ "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
35
+ "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
36
+ "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
37
+ "encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
38
+ "encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
39
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
40
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
41
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
42
+ "encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
43
+ "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
44
+ "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
45
+ "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
46
+ "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
47
+ "encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
48
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
49
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
50
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
51
+ "encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
52
+ "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
53
+ "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
54
+ "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
55
+ "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
56
+ "encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
57
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
58
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
59
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
60
+ "encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
61
+ "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
62
+ "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
63
+ "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
64
+ "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
65
+ "encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
66
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
67
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
68
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
69
+ "encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
70
+ "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
71
+ "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
72
+ "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
73
+ "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
74
+ "encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
75
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
76
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
77
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
78
+ "encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
79
+ "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
80
+ "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
81
+ "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
82
+ "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
83
+ "encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
84
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
85
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
86
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
87
+ "encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
88
+ "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
89
+ "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
90
+ "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
91
+ "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
92
+ "encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
93
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
94
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
95
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
96
+ "encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
97
+ "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
98
+ "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
99
+ "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
100
+ "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
101
+ "encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
102
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
103
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
104
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
105
+ "encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
106
+ "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
107
+ "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
108
+ "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
109
+ "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
110
+ "encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
111
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
112
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
113
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
114
+ "encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
115
+ "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
116
+ "encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
117
+ "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
118
+ "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
119
+ "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
120
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
121
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
122
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
123
+ "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
124
+ "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
125
+ "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
126
+ "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
127
+ "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
128
+ "encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
129
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
130
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
131
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
132
+ "encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
133
+ "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
134
+ "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
135
+ "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
136
+ "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
137
+ "encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
138
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
139
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
140
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
141
+ "encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
142
+ "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
143
+ "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
144
+ "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
145
+ "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
146
+ "encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
147
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
148
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
149
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
150
+ "encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
151
+ "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
152
+ "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
153
+ "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
154
+ "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
155
+ "encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
156
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
157
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
158
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
159
+ "encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
160
+ "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
161
+ "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
162
+ "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
163
+ "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
164
+ "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
165
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
166
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
167
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
168
+ "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
169
+ "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
170
+ "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
171
+ "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
172
+ "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
173
+ "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
174
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
175
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
176
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
177
+ "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
178
+ "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
179
+ "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
180
+ "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
181
+ "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
182
+ "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
183
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
184
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
185
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
186
+ "encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
187
+ "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
188
+ "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
189
+ "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
190
+ "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
191
+ "encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
192
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
193
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
194
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
195
+ "encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
196
+ "encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
197
+ "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
198
+ "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
199
+ "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
200
+ "encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
201
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
202
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
203
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
204
+ "encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
205
+ "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
206
+ "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
207
+ "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
208
+ "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
209
+ "encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
210
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
211
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
212
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
213
+ "encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
214
+ "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
215
+ "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
216
+ "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
217
+ "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
218
+ "encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
219
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
220
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
221
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
222
+ "encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
223
+ "encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
224
+ "shared.weight": "model-00001-of-00002.safetensors"
225
+ }
226
+ }
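The map above is the tail of a sharded safetensors index for the T5 text encoder: each parameter name points at the shard file that stores it (block 12 happens to be split across the two shards). A minimal sketch, assuming this index is saved as text_encoder_2/model.safetensors.index.json in a local clone, of pulling one tensor from the correct shard without loading the whole checkpoint:

import json
from safetensors import safe_open

base = "LibreFlux-ControlNet/text_encoder_2"  # assumed local path
with open(f"{base}/model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]   # parameter name -> shard filename

name = "encoder.block.12.layer.0.SelfAttention.o.weight"
shard = weight_map[name]                      # "model-00002-of-00002.safetensors"
with safe_open(f"{base}/{shard}", framework="pt") as st:
    tensor = st.get_tensor(name)              # load just this one tensor
print(shard, tuple(tensor.shape))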
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "tokenizer_class": "CLIPTokenizer",
29
+ "unk_token": "<|endoftext|>"
30
+ }
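This is the standard CLIP tokenizer configuration (start/end token ids 49406/49407, 77-token context). A minimal loading sketch; the repo id is a placeholder, not the actual published name:

from transformers import CLIPTokenizer

tok = CLIPTokenizer.from_pretrained("your-namespace/LibreFlux-ControlNet", subfolder="tokenizer")
enc = tok("a photo of a cat", padding="max_length", max_length=tok.model_max_length)
print(tok.model_max_length)        # 77, from tokenizer_config.json
print(enc["input_ids"][:2])        # starts with <|startoftext|> (49406)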
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
tokenizer_2/spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
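As with the other binaries in this commit, spiece.model is stored as a Git LFS pointer: the three lines record only the LFS spec version, the SHA-256 of the real file, and its size in bytes (791,656, which matches the stock T5 SentencePiece model). A small sketch for reading such a pointer from a checkout where the blob has not been fetched:

def parse_lfs_pointer(path):
    # Each line is "<key> <value>", e.g. "oid sha256:d60acb..." or "size 791656".
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = parse_lfs_pointer("tokenizer_2/spiece.model")   # assumed local path
print(ptr["oid"], int(ptr["size"]))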
tokenizer_2/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,940 @@
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "sp_model_kwargs": {},
938
+ "tokenizer_class": "T5Tokenizer",
939
+ "unk_token": "<unk>"
940
+ }
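tokenizer_2 is a stock T5 tokenizer configuration: the 100 <extra_id_*> sentinel tokens occupy ids 32000-32099 in reverse order (as the added_tokens_decoder above shows), and model_max_length is 512, the default cap on T5 prompt length. A loading sketch with a placeholder repo id:

from transformers import T5Tokenizer

tok2 = T5Tokenizer.from_pretrained("your-namespace/LibreFlux-ControlNet", subfolder="tokenizer_2")
print(tok2.model_max_length)                       # 512
print(tok2.convert_tokens_to_ids("<extra_id_0>"))  # 32099, per added_tokens_decoder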
transformer/__init__.py ADDED
File without changes
transformer/config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "_class_name": "FluxTransformer2DModel",
3
+ "_diffusers_version": "0.30.0.dev0",
4
+ "_name_or_path": "/path/to/transformer",
5
+ "attention_head_dim": 128,
6
+ "axes_dims_rope": [
7
+ 16,
8
+ 56,
9
+ 56
10
+ ],
11
+ "guidance_embeds": false,
12
+ "in_channels": 64,
13
+ "joint_attention_dim": 4096,
14
+ "num_attention_heads": 24,
15
+ "num_layers": 19,
16
+ "num_single_layers": 38,
17
+ "patch_size": 1,
18
+ "pooled_projection_dim": 768
19
+ }
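This config describes the full Flux-scale transformer: 19 double-stream and 38 single-stream blocks, 24 attention heads of dimension 128 (a 3072-wide model), with guidance_embeds disabled. A hedged sketch of loading the subfolder with the custom class defined in transformer/trans.py, added later in this commit; the import path, local clone location, and dtype are assumptions:

import torch
from transformer.trans import LibreFluxTransformer2DModel  # class defined in transformer/trans.py

transformer = LibreFluxTransformer2DModel.from_pretrained(
    "./LibreFlux-ControlNet",        # assumed local clone
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
print(transformer.config.num_layers, transformer.config.num_single_layers)  # 19 38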
transformer/diffusion_pytorch_model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18c2abe01a326d95bc836cfd5f68167118c0ecb2c8ccbcf5d6de4dbad47ca53c
3
+ size 9962580296
transformer/diffusion_pytorch_model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:828f131e306b17535c8c1d0a3c4aaa06f2a60a80612500da229829242f3ed422
3
+ size 9949328904
transformer/diffusion_pytorch_model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a17988ef4255372dd902ff5742e647f8a60dcc83756d740d1fbcf81d13d38162
3
+ size 3870584832
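The three LFS pointers expose the shard sizes directly, so the transformer's footprint can be estimated without downloading anything: about 23.8 GB on disk, consistent with roughly 11.9 billion parameters at 2 bytes per weight (bf16/fp16). A quick arithmetic check:

shard_bytes = [9962580296, 9949328904, 3870584832]        # sizes from the pointer files above
total = sum(shard_bytes)
print(f"{total / 1e9:.1f} GB")                            # 23.8 GB
print(f"~{total / 2 / 1e9:.1f}B params at 2 bytes each")  # ~11.9B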
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
transformer/trans.py ADDED
@@ -0,0 +1,766 @@
1
+ #################################
2
+ ##### TRANSFORMER MERGE #########
3
+ #################################
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
14
+ from diffusers.models.attention import FeedForward
15
+ from diffusers.models.attention_processor import (
16
+ Attention,
17
+ AttentionProcessor,
18
+ )
19
+ from diffusers.models.modeling_utils import ModelMixin
20
+ from diffusers.models.normalization import (
21
+ AdaLayerNormContinuous,
22
+ AdaLayerNormZero,
23
+ AdaLayerNormZeroSingle,
24
+ )
25
+ from diffusers.utils import (
26
+ USE_PEFT_BACKEND,
27
+ is_torch_version,
28
+ logging,
29
+ scale_lora_layers,
30
+ unscale_lora_layers,
31
+ )
32
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
33
+ from diffusers.models.embeddings import (
34
+ CombinedTimestepGuidanceTextProjEmbeddings,
35
+ CombinedTimestepTextProjEmbeddings,
36
+ FluxPosEmbed,
37
+ )
38
+
39
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
40
+ from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel
41
+
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+ is_flash_attn_available = False
46
+
47
+
48
+
49
+ class FluxAttnProcessor2_0:
50
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
51
+
52
+ def __init__(self):
53
+ if not hasattr(F, "scaled_dot_product_attention"):
54
+ raise ImportError(
55
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
56
+ )
57
+
58
+ def __call__(
59
+ self,
60
+ attn: Attention,
61
+ hidden_states: torch.FloatTensor,
62
+ encoder_hidden_states: torch.FloatTensor = None,
63
+ attention_mask: Optional[torch.FloatTensor] = None,
64
+ image_rotary_emb: Optional[torch.Tensor] = None,
65
+ ) -> torch.FloatTensor:
66
+ batch_size, _, _ = (
67
+ hidden_states.shape
68
+ if encoder_hidden_states is None
69
+ else encoder_hidden_states.shape
70
+ )
71
+
72
+ # `sample` projections.
73
+ query = attn.to_q(hidden_states)
74
+ key = attn.to_k(hidden_states)
75
+ value = attn.to_v(hidden_states)
76
+
77
+ inner_dim = key.shape[-1]
78
+ head_dim = inner_dim // attn.heads
79
+
80
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
81
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
82
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
83
+
84
+ if attn.norm_q is not None:
85
+ query = attn.norm_q(query)
86
+ if attn.norm_k is not None:
87
+ key = attn.norm_k(key)
88
+
89
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
90
+ if encoder_hidden_states is not None:
91
+ # `context` projections.
92
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
93
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
94
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
95
+
96
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
97
+ batch_size, -1, attn.heads, head_dim
98
+ ).transpose(1, 2)
99
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
100
+ batch_size, -1, attn.heads, head_dim
101
+ ).transpose(1, 2)
102
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
103
+ batch_size, -1, attn.heads, head_dim
104
+ ).transpose(1, 2)
105
+
106
+ if attn.norm_added_q is not None:
107
+ encoder_hidden_states_query_proj = attn.norm_added_q(
108
+ encoder_hidden_states_query_proj
109
+ )
110
+ if attn.norm_added_k is not None:
111
+ encoder_hidden_states_key_proj = attn.norm_added_k(
112
+ encoder_hidden_states_key_proj
113
+ )
114
+
115
+ # attention
116
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
117
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
118
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
119
+
120
+ if image_rotary_emb is not None:
121
+ from diffusers.models.embeddings import apply_rotary_emb
122
+
123
+ query = apply_rotary_emb(query, image_rotary_emb)
124
+ key = apply_rotary_emb(key, image_rotary_emb)
125
+
126
+ if attention_mask is not None:
127
+ # Broadcast the text-token mask over heads and query positions: (B, S) -> (B, 1, 1, S)
128
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
129
+ attention_mask = (attention_mask > 0).bool()
130
+ # Move the mask to the query's device and dtype
131
+ attention_mask = attention_mask.to(
132
+ device=hidden_states.device, dtype=query.dtype
133
+ )
134
+
135
+ hidden_states = F.scaled_dot_product_attention(
136
+ query,
137
+ key,
138
+ value,
139
+ dropout_p=0.0,
140
+ is_causal=False,
141
+ attn_mask=attention_mask,
142
+ )
143
+ hidden_states = hidden_states.transpose(1, 2).reshape(
144
+ batch_size, -1, attn.heads * head_dim
145
+ )
146
+ hidden_states = hidden_states.to(query.dtype)
147
+
148
+ if encoder_hidden_states is not None:
149
+ encoder_hidden_states, hidden_states = (
150
+ hidden_states[:, : encoder_hidden_states.shape[1]],
151
+ hidden_states[:, encoder_hidden_states.shape[1] :],
152
+ )
153
+
154
+ # linear proj
155
+ hidden_states = attn.to_out[0](hidden_states)
156
+ # dropout
157
+ hidden_states = attn.to_out[1](hidden_states)
158
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
159
+
160
+ return hidden_states, encoder_hidden_states
161
+ return hidden_states
162
+
163
+
164
+ def expand_flux_attention_mask(
165
+ hidden_states: torch.Tensor,
166
+ attn_mask: torch.Tensor,
167
+ ) -> torch.Tensor:
168
+ """
169
+ Expand a text-token attention mask to the full joint sequence length so that image tokens stay unmasked.
170
+ """
171
+ bsz = attn_mask.shape[0]
172
+ assert bsz == hidden_states.shape[0]
173
+ residual_seq_len = hidden_states.shape[1]
174
+ mask_seq_len = attn_mask.shape[1]
175
+
176
+ expanded_mask = torch.ones(bsz, residual_seq_len)
177
+ expanded_mask[:, :mask_seq_len] = attn_mask
178
+
179
+ return expanded_mask
180
+
181
+
182
+ @maybe_allow_in_graph
183
+ class FluxSingleTransformerBlock(nn.Module):
184
+ r"""
185
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
186
+
187
+ Reference: https://arxiv.org/abs/2403.03206
188
+
189
+ Parameters:
190
+ dim (`int`): The number of channels in the input and output.
191
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
192
+ attention_head_dim (`int`): The number of channels in each head.
193
+ mlp_ratio (`float`, *optional*, defaults to 4.0): The ratio of the hidden dimension of the parallel
194
+ feed-forward branch to `dim`.
195
+ """
196
+
197
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
198
+ super().__init__()
199
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
200
+
201
+ self.norm = AdaLayerNormZeroSingle(dim)
202
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
203
+ self.act_mlp = nn.GELU(approximate="tanh")
204
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
205
+
206
+ processor = FluxAttnProcessor2_0()
207
+ self.attn = Attention(
208
+ query_dim=dim,
209
+ cross_attention_dim=None,
210
+ dim_head=attention_head_dim,
211
+ heads=num_attention_heads,
212
+ out_dim=dim,
213
+ bias=True,
214
+ processor=processor,
215
+ qk_norm="rms_norm",
216
+ eps=1e-6,
217
+ pre_only=True,
218
+ )
219
+
220
+ def forward(
221
+ self,
222
+ hidden_states: torch.FloatTensor,
223
+ temb: torch.FloatTensor,
224
+ image_rotary_emb=None,
225
+ attention_mask: Optional[torch.Tensor] = None,
226
+ ):
227
+ residual = hidden_states
228
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
229
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
230
+
231
+ if attention_mask is not None:
232
+ attention_mask = expand_flux_attention_mask(
233
+ hidden_states,
234
+ attention_mask,
235
+ )
236
+
237
+ attn_output = self.attn(
238
+ hidden_states=norm_hidden_states,
239
+ image_rotary_emb=image_rotary_emb,
240
+ attention_mask=attention_mask,
241
+ )
242
+
243
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
244
+ gate = gate.unsqueeze(1)
245
+ hidden_states = gate * self.proj_out(hidden_states)
246
+ hidden_states = residual + hidden_states
247
+
248
+ if hidden_states.dtype == torch.float16:
249
+ hidden_states = hidden_states.clip(-65504, 65504)
250
+
251
+ return hidden_states
252
+
253
+
254
+ @maybe_allow_in_graph
255
+ class FluxTransformerBlock(nn.Module):
256
+ r"""
257
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
258
+
259
+ Reference: https://arxiv.org/abs/2403.03206
260
+
261
+ Parameters:
262
+ dim (`int`): The number of channels in the input and output.
263
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
264
+ attention_head_dim (`int`): The number of channels in each head.
265
+ qk_norm (`str`, *optional*, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
266
+ eps (`float`, *optional*, defaults to 1e-6): Epsilon used by the query/key normalization layers.
267
+ """
268
+
269
+ def __init__(
270
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
271
+ ):
272
+ super().__init__()
273
+
274
+ self.norm1 = AdaLayerNormZero(dim)
275
+
276
+ self.norm1_context = AdaLayerNormZero(dim)
277
+
278
+ if hasattr(F, "scaled_dot_product_attention"):
279
+ processor = FluxAttnProcessor2_0()
280
+ else:
281
+ raise ValueError(
282
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
283
+ )
284
+ self.attn = Attention(
285
+ query_dim=dim,
286
+ cross_attention_dim=None,
287
+ added_kv_proj_dim=dim,
288
+ dim_head=attention_head_dim,
289
+ heads=num_attention_heads,
290
+ out_dim=dim,
291
+ context_pre_only=False,
292
+ bias=True,
293
+ processor=processor,
294
+ qk_norm=qk_norm,
295
+ eps=eps,
296
+ )
297
+
298
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
299
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
300
+
301
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
302
+ self.ff_context = FeedForward(
303
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
304
+ )
305
+
306
+ # let chunk size default to None
307
+ self._chunk_size = None
308
+ self._chunk_dim = 0
309
+
310
+ def forward(
311
+ self,
312
+ hidden_states: torch.FloatTensor,
313
+ encoder_hidden_states: torch.FloatTensor,
314
+ temb: torch.FloatTensor,
315
+ image_rotary_emb=None,
316
+ attention_mask: Optional[torch.Tensor] = None,
317
+ ):
318
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
319
+ hidden_states, emb=temb
320
+ )
321
+
322
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
323
+ self.norm1_context(encoder_hidden_states, emb=temb)
324
+ )
325
+
326
+ if attention_mask is not None:
327
+ attention_mask = expand_flux_attention_mask(
328
+ torch.cat([encoder_hidden_states, hidden_states], dim=1),
329
+ attention_mask,
330
+ )
331
+
332
+ # Attention.
333
+ attention_outputs = self.attn(
334
+ hidden_states=norm_hidden_states,
335
+ encoder_hidden_states=norm_encoder_hidden_states,
336
+ image_rotary_emb=image_rotary_emb,
337
+ attention_mask=attention_mask,
338
+ )
339
+ if len(attention_outputs) == 2:
340
+ attn_output, context_attn_output = attention_outputs
341
+ elif len(attention_outputs) == 3:
342
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
343
+
344
+ # Process attention outputs for the `hidden_states`.
345
+ attn_output = gate_msa.unsqueeze(1) * attn_output
346
+ hidden_states = hidden_states + attn_output
347
+
348
+ norm_hidden_states = self.norm2(hidden_states)
349
+ norm_hidden_states = (
350
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
351
+ )
352
+
353
+ ff_output = self.ff(norm_hidden_states)
354
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
355
+
356
+ hidden_states = hidden_states + ff_output
357
+ if len(attention_outputs) == 3:
358
+ hidden_states = hidden_states + ip_attn_output
359
+
360
+ # Process attention outputs for the `encoder_hidden_states`.
361
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
362
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
363
+
364
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
365
+ norm_encoder_hidden_states = (
366
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
367
+ + c_shift_mlp[:, None]
368
+ )
369
+
370
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
371
+ encoder_hidden_states = (
372
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
373
+ )
374
+
375
+ if encoder_hidden_states.dtype == torch.float16:
376
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
377
+
378
+ return encoder_hidden_states, hidden_states
379
+
380
+
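A matching rough sketch for the dual-stream block defined above (again toy sizes, illustrative only): image and text tokens stay in separate streams and both are returned, updated.

    import torch

    block = FluxTransformerBlock(dim=512, num_attention_heads=8, attention_head_dim=64)
    image_tokens = torch.randn(1, 64, 512)   # packed latent tokens
    text_tokens = torch.randn(1, 16, 512)    # projected text tokens
    temb = torch.randn(1, 512)
    text_tokens, image_tokens = block(       # note the return order: (encoder_hidden_states, hidden_states)
        hidden_states=image_tokens,
        encoder_hidden_states=text_tokens,
        temb=temb,
    )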
381
+ class LibreFluxTransformer2DModel(
382
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
383
+ ):
384
+ """
385
+ The Transformer model introduced in Flux.
386
+
387
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
388
+
389
+ Parameters:
390
+ patch_size (`int`): Patch size to turn the input data into small patches.
391
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
393
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
394
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
395
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
396
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
396
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
397
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
398
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
399
+ """
400
+
401
+ _supports_gradient_checkpointing = True
402
+
403
+ @register_to_config
404
+ def __init__(
405
+ self,
406
+ patch_size: int = 1,
407
+ in_channels: int = 64,
408
+ num_layers: int = 19,
409
+ num_single_layers: int = 38,
410
+ attention_head_dim: int = 128,
411
+ num_attention_heads: int = 24,
412
+ joint_attention_dim: int = 4096,
413
+ pooled_projection_dim: int = 768,
414
+ guidance_embeds: bool = False,
415
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
416
+ ):
417
+ super().__init__()
418
+ self.out_channels = in_channels
419
+ self.inner_dim = (
420
+ self.config.num_attention_heads * self.config.attention_head_dim
421
+ )
422
+
423
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
424
+ text_time_guidance_cls = (
425
+ CombinedTimestepGuidanceTextProjEmbeddings  # forward(timestep, guidance, pooled_projection)
426
+ if guidance_embeds
427
+ else CombinedTimestepTextProjEmbeddings  # forward(timestep, pooled_projection)
428
+ )
429
+ self.time_text_embed = text_time_guidance_cls(
430
+ embedding_dim=self.inner_dim,
431
+ pooled_projection_dim=self.config.pooled_projection_dim,
432
+ )
433
+
434
+ self.context_embedder = nn.Linear(
435
+ self.config.joint_attention_dim, self.inner_dim
436
+ )
437
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
438
+
439
+ self.transformer_blocks = nn.ModuleList(
440
+ [
441
+ FluxTransformerBlock(
442
+ dim=self.inner_dim,
443
+ num_attention_heads=self.config.num_attention_heads,
444
+ attention_head_dim=self.config.attention_head_dim,
445
+ )
446
+ for i in range(self.config.num_layers)
447
+ ]
448
+ )
449
+
450
+ self.single_transformer_blocks = nn.ModuleList(
451
+ [
452
+ FluxSingleTransformerBlock(
453
+ dim=self.inner_dim,
454
+ num_attention_heads=self.config.num_attention_heads,
455
+ attention_head_dim=self.config.attention_head_dim,
456
+ )
457
+ for i in range(self.config.num_single_layers)
458
+ ]
459
+ )
460
+
461
+ self.norm_out = AdaLayerNormContinuous(
462
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
463
+ )
464
+ self.proj_out = nn.Linear(
465
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
466
+ )
467
+
468
+ self.gradient_checkpointing = False
469
+ # added for users to disable checkpointing every nth step
470
+ self.gradient_checkpointing_interval = None
471
+
472
+ def set_gradient_checkpointing_interval(self, value: int):
473
+ self.gradient_checkpointing_interval = value
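The interval above is a repo-specific addition on top of the usual gradient checkpointing flag. A minimal sketch of how it might be used during training (a deliberately tiny configuration for illustration, not the real 19+38-block model):

    model = LibreFluxTransformer2DModel(
        num_layers=4, num_single_layers=8,
        attention_head_dim=32, num_attention_heads=4, axes_dims_rope=(8, 12, 12),
    )
    model.gradient_checkpointing = True           # the forward pass reads this flag directly
    model.set_gradient_checkpointing_interval(2)  # checkpoint only every 2nd block while training
    model.train()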
474
+
475
+ @property
476
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
477
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
478
+ r"""
479
+ Returns:
480
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
481
+ indexed by their weight names.
482
+ """
483
+ # set recursively
484
+ processors = {}
485
+
486
+ def fn_recursive_add_processors(
487
+ name: str,
488
+ module: torch.nn.Module,
489
+ processors: Dict[str, AttentionProcessor],
490
+ ):
491
+ if hasattr(module, "get_processor"):
492
+ processors[f"{name}.processor"] = module.get_processor()
493
+
494
+ for sub_name, child in module.named_children():
495
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
496
+
497
+ return processors
498
+
499
+ for name, module in self.named_children():
500
+ fn_recursive_add_processors(name, module, processors)
501
+
502
+ return processors
503
+
504
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
505
+ def set_attn_processor(
506
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
507
+ ):
508
+ r"""
509
+ Sets the attention processor to use to compute attention.
510
+
511
+ Parameters:
512
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
513
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
514
+ for **all** `Attention` layers.
515
+
516
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
517
+ processor. This is strongly recommended when setting trainable attention processors.
518
+
519
+ """
520
+ count = len(self.attn_processors.keys())
521
+
522
+ if isinstance(processor, dict) and len(processor) != count:
523
+ raise ValueError(
524
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
525
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
526
+ )
527
+
528
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
529
+ if hasattr(module, "set_processor"):
530
+ if not isinstance(processor, dict):
531
+ module.set_processor(processor)
532
+ else:
533
+ module.set_processor(processor.pop(f"{name}.processor"))
534
+
535
+ for sub_name, child in module.named_children():
536
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
537
+
538
+ for name, module in self.named_children():
539
+ fn_recursive_attn_processor(name, module, processor)
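A short sketch of inspecting and swapping processors (assuming a small model instance as in the earlier sketch; the key format follows the module paths built above):

    from diffusers.models.attention_processor import FluxAttnProcessor2_0

    procs = model.attn_processors                     # e.g. {"transformer_blocks.0.attn.processor": ..., ...}
    model.set_attn_processor(FluxAttnProcessor2_0())  # one processor instance applied to every Attention layer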
540
+
541
+ def forward(
542
+ self,
543
+ hidden_states: torch.Tensor,
544
+ encoder_hidden_states: torch.Tensor = None,
545
+ pooled_projections: torch.Tensor = None,
546
+ timestep: torch.LongTensor = None,
547
+ img_ids: torch.Tensor = None,
548
+ txt_ids: torch.Tensor = None,
549
+ guidance: torch.Tensor = None,
550
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
551
+ controlnet_block_samples=None,
552
+ controlnet_single_block_samples=None,
553
+ return_dict: bool = True,
554
+ attention_mask: Optional[torch.Tensor] = None,
555
+ controlnet_blocks_repeat: bool = False,
556
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
557
+ """
558
+ The [`FluxTransformer2DModel`] forward method.
559
+
560
+ Args:
561
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
562
+ Input `hidden_states` (packed latent tokens).
563
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
564
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
565
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
566
+ from the embeddings of input conditions.
567
+ timestep ( `torch.LongTensor`):
568
+ Used to indicate denoising step.
569
+ controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
570
+ ControlNet residuals added to the dual-stream block outputs; `controlnet_single_block_samples` plays the same role for the single-stream blocks.
571
+ joint_attention_kwargs (`dict`, *optional*):
572
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
573
+ `self.processor` in
574
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
575
+ return_dict (`bool`, *optional*, defaults to `True`):
576
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
577
+ tuple.
578
+
579
+ Returns:
580
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
581
+ `tuple` where the first element is the sample tensor.
582
+ """
583
+ if joint_attention_kwargs is not None:
584
+ joint_attention_kwargs = joint_attention_kwargs.copy()
585
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
586
+ else:
587
+ lora_scale = 1.0
588
+
589
+ if USE_PEFT_BACKEND:
590
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
591
+ scale_lora_layers(self, lora_scale)
592
+ else:
593
+ if (
594
+ joint_attention_kwargs is not None
595
+ and joint_attention_kwargs.get("scale", None) is not None
596
+ ):
597
+ logger.warning(
598
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
599
+ )
600
+ hidden_states = self.x_embedder(hidden_states)
601
+
602
+ timestep = timestep.to(hidden_states.dtype) * 1000
603
+ if guidance is not None:
604
+ guidance = guidance.to(hidden_states.dtype) * 1000
605
+ else:
606
+ guidance = None
607
+
608
+ # the guidance-distilled embedder takes an extra `guidance` argument;
609
+ # see the `time_text_embed` selection in `__init__`.
610
+ temb = (
611
+ self.time_text_embed(timestep, pooled_projections)
612
+ if guidance is None
613
+ else self.time_text_embed(timestep, guidance, pooled_projections)
614
+ )
615
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
616
+
617
+ if txt_ids.ndim == 3:
618
+ txt_ids = txt_ids[0]
619
+ if img_ids.ndim == 3:
620
+ img_ids = img_ids[0]
621
+
622
+ ids = torch.cat((txt_ids, img_ids), dim=0)
623
+
624
+ image_rotary_emb = self.pos_embed(ids)
625
+
626
+ # IP adapter
627
+ if (
628
+ joint_attention_kwargs is not None
629
+ and "ip_adapter_image_embeds" in joint_attention_kwargs
630
+ ):
631
+ ip_adapter_image_embeds = joint_attention_kwargs.pop(
632
+ "ip_adapter_image_embeds"
633
+ )
634
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
635
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
636
+
637
+ for index_block, block in enumerate(self.transformer_blocks):
638
+ if (
639
+ self.training
640
+ and self.gradient_checkpointing
641
+ and (
642
+ self.gradient_checkpointing_interval is None
643
+ or index_block % self.gradient_checkpointing_interval == 0
644
+ )
645
+ ):
646
+
647
+ def create_custom_forward(module, return_dict=None):
648
+ def custom_forward(*inputs):
649
+ if return_dict is not None:
650
+ return module(*inputs, return_dict=return_dict)
651
+ else:
652
+ return module(*inputs)
653
+
654
+ return custom_forward
655
+
656
+ ckpt_kwargs: Dict[str, Any] = (
657
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
658
+ )
659
+ encoder_hidden_states, hidden_states = (
660
+ torch.utils.checkpoint.checkpoint(
661
+ create_custom_forward(block),
662
+ hidden_states,
663
+ encoder_hidden_states,
664
+ temb,
665
+ image_rotary_emb,
666
+ attention_mask,
667
+ **ckpt_kwargs,
668
+ )
669
+ )
670
+
671
+ else:
672
+ encoder_hidden_states, hidden_states = block(
673
+ hidden_states=hidden_states,
674
+ encoder_hidden_states=encoder_hidden_states,
675
+ temb=temb,
676
+ image_rotary_emb=image_rotary_emb,
677
+ attention_mask=attention_mask,
678
+ )
679
+
680
+ # controlnet residual
681
+ if controlnet_block_samples is not None:
682
+ interval_control = len(self.transformer_blocks) / len(
683
+ controlnet_block_samples
684
+ )
685
+ interval_control = int(np.ceil(interval_control))
686
+ # For Xlabs ControlNet.
687
+ if controlnet_blocks_repeat:
688
+ hidden_states = (
689
+ hidden_states
690
+ + controlnet_block_samples[
691
+ index_block % len(controlnet_block_samples)
692
+ ]
693
+ )
694
+ else:
695
+ hidden_states = (
696
+ hidden_states
697
+ + controlnet_block_samples[index_block // interval_control]
698
+ )
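                # Illustrative arithmetic: with the default 19 dual-stream blocks and, say, 2 ControlNet
                # residuals, interval_control = ceil(19 / 2) = 10, so blocks 0-9 add residual 0 and
                # blocks 10-18 add residual 1 (unless controlnet_blocks_repeat cycles through them instead).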
699
+
700
+ # Flux places the text tokens in front of the image tokens in the
701
+ # sequence.
702
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
703
+
704
+ for index_block, block in enumerate(self.single_transformer_blocks):
705
+ if (
706
+ self.training
707
+ and self.gradient_checkpointing
708
+ and (
709
+ self.gradient_checkpointing_interval is None
710
+ or index_block % self.gradient_checkpointing_interval == 0
711
+ )
712
+ ):
713
+
714
+ def create_custom_forward(module, return_dict=None):
715
+ def custom_forward(*inputs):
716
+ if return_dict is not None:
717
+ return module(*inputs, return_dict=return_dict)
718
+ else:
719
+ return module(*inputs)
720
+
721
+ return custom_forward
722
+
723
+ ckpt_kwargs: Dict[str, Any] = (
724
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
725
+ )
726
+ hidden_states = torch.utils.checkpoint.checkpoint(
727
+ create_custom_forward(block),
728
+ hidden_states,
729
+ temb,
730
+ image_rotary_emb,
731
+ attention_mask,
732
+ **ckpt_kwargs,
733
+ )
734
+
735
+ else:
736
+ hidden_states = block(
737
+ hidden_states=hidden_states,
738
+ temb=temb,
739
+ image_rotary_emb=image_rotary_emb,
740
+ attention_mask=attention_mask,
741
+ )
742
+
743
+ # controlnet residual
744
+ if controlnet_single_block_samples is not None:
745
+ interval_control = len(self.single_transformer_blocks) / len(
746
+ controlnet_single_block_samples
747
+ )
748
+ interval_control = int(np.ceil(interval_control))
749
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
750
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
751
+ + controlnet_single_block_samples[index_block // interval_control]
752
+ )
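                # Note: the slice starting at encoder_hidden_states.shape[1] targets only the image tokens;
                # the text tokens at the front of the concatenated sequence receive no ControlNet residual.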
753
+
754
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
755
+
756
+ hidden_states = self.norm_out(hidden_states, temb)
757
+ output = self.proj_out(hidden_states)
758
+
759
+ if USE_PEFT_BACKEND:
760
+ # remove `lora_scale` from each PEFT layer
761
+ unscale_lora_layers(self, lora_scale)
762
+
763
+ if not return_dict:
764
+ return (output,)
765
+
766
+ return Transformer2DModelOutput(sample=output)
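To make the calling convention concrete, here is a rough, toy-sized sketch of a forward pass (all sizes are illustrative; the production configuration uses 24 heads of dim 128 with 19 dual-stream and 38 single-stream blocks):

    import torch

    model = LibreFluxTransformer2DModel(
        num_layers=2, num_single_layers=4,
        attention_head_dim=32, num_attention_heads=4, axes_dims_rope=(8, 12, 12),
    ).eval()

    bsz, n_img, n_txt = 1, 64, 16
    with torch.no_grad():
        out = model(
            hidden_states=torch.randn(bsz, n_img, 64),            # packed VAE latent tokens
            encoder_hidden_states=torch.randn(bsz, n_txt, 4096),  # T5-style text hidden states
            pooled_projections=torch.randn(bsz, 768),             # pooled text embedding
            timestep=torch.tensor([0.5]),                         # in [0, 1]; scaled by 1000 internally
            img_ids=torch.zeros(n_img, 3),
            txt_ids=torch.zeros(n_txt, 3),
        )
    print(out.sample.shape)  # torch.Size([1, 64, 64])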
vae/config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.30.0.dev0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 16,
20
+ "latents_mean": null,
21
+ "latents_std": null,
22
+ "layers_per_block": 2,
23
+ "mid_block_add_attention": true,
24
+ "norm_num_groups": 32,
25
+ "out_channels": 3,
26
+ "sample_size": 1024,
27
+ "scaling_factor": 0.3611,
28
+ "shift_factor": 0.1159,
29
+ "up_block_types": [
30
+ "UpDecoderBlock2D",
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D"
34
+ ],
35
+ "use_post_quant_conv": false,
36
+ "use_quant_conv": false
37
+ }
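This VAE config mirrors the Flux autoencoder (16 latent channels, scaling_factor 0.3611, shift_factor 0.1159). A hedged sketch of loading it and undoing the latent scaling before decoding, the way Flux-style pipelines do; the repo id below is a placeholder:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        "your-namespace/libreflux-controlnet",   # placeholder repo id
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )

    latents = torch.randn(1, 16, 32, 32, dtype=torch.bfloat16)    # 16 latent channels per the config
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    image = vae.decode(latents).sample                            # (1, 3, 256, 256) at 8x upsampling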
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5b59a26851551b67ae1fe58d32e76486e1e812def4696a4bea97f16604d40a3
3
+ size 167666902