Commit 1f9e684 — aredden committed (1 parent: fb7df61)

Remove f8 flux, instead configure at load, improved quality & corrected configs
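The gist of "configure at load": instead of keeping a separate hard-coded FP8 model (modules/flux_model_f8.py, deleted below), the unified model now decides per layer group whether to build an F8Linear or a plain nn.Linear from the config flags. A rough, self-contained sketch of that selection logic follows; `layer_kind` and the slot names used as dictionary keys are illustrative, not functions in the repo.

```python
# Sketch (assumption-labelled): which layer class a slot would get under the
# flags wired up in this commit. Simplified; the real logic lives in
# modules/flux_model.py and float8_quantize.py.
def layer_kind(
    slot: str,
    prequantized_flow: bool,
    quantize_flow_embedder_layers: bool,
    quantize_modulation: bool,
) -> str:
    quantized_embedders = quantize_flow_embedder_layers and prequantized_flow
    quantized_modulation = quantize_modulation and prequantized_flow
    if not prequantized_flow:
        return "nn.Linear"  # quantization happens later, after load
    if slot in {"img_in", "txt_in", "time_in", "vector_in", "guidance_in"}:
        return "F8Linear" if quantized_embedders else "nn.Linear"
    if slot == "modulation.lin":
        return "F8Linear" if quantized_modulation else "nn.Linear"
    return "F8Linear"  # attention/MLP projections inside the blocks


print(layer_kind("txt_in", True, False, True))          # nn.Linear
print(layer_kind("modulation.lin", True, False, True))  # F8Linear
```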
float8_quantize.py CHANGED

```diff
@@ -424,28 +424,28 @@ def quantize_flow_transformer_and_dispatch_float8(
             continue
         m_extra.to(device)
         m_extra.eval()
-        if (
-            isinstance(m_extra, nn.Linear)
-            and not isinstance(m_extra, (F8Linear, CublasLinear))
-            and quantize_flow_embedder_layers
+        if isinstance(m_extra, nn.Linear) and not isinstance(
+            m_extra, (F8Linear, CublasLinear)
         ):
-            setattr(
-                flow_model,
-                module,
-                F8Linear.from_linear(
+            if quantize_flow_embedder_layers:
+                setattr(
+                    flow_model,
+                    module,
+                    F8Linear.from_linear(
+                        m_extra,
+                        float8_dtype=float8_dtype,
+                        input_float8_dtype=input_float8_dtype,
+                    ),
+                )
+                del m_extra
+        elif module != "final_layer":
+            if quantize_flow_embedder_layers:
+                recursive_swap_linears(
                     m_extra,
                     float8_dtype=float8_dtype,
                     input_float8_dtype=input_float8_dtype,
-                ),
-            )
-            del m_extra
-        elif module != "final_layer" and not quantize_flow_embedder_layers:
-            recursive_swap_linears(
-                m_extra,
-                float8_dtype=float8_dtype,
-                input_float8_dtype=input_float8_dtype,
-                quantize_modulation=quantize_modulation,
-            )
+                    quantize_modulation=quantize_modulation,
+                )
         torch.cuda.empty_cache()
     if swap_linears_with_cublaslinear and flow_dtype == torch.float16:
         swap_to_cublaslinear(flow_model)
```
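After this change, the embedder-style submodules are only swapped to FP8 when `quantize_flow_embedder_layers` is set, and `final_layer` is never touched by the recursive swap. A minimal, self-contained sketch of that gating pattern is below; `QuantLinearStub` and `swap_linears` are hypothetical stand-ins, not the repo's `F8Linear` / `recursive_swap_linears`.

```python
import torch
from torch import nn


class QuantLinearStub(nn.Linear):
    """Hypothetical stand-in for an FP8 linear layer (e.g. F8Linear)."""

    @classmethod
    def from_linear(cls, lin: nn.Linear) -> "QuantLinearStub":
        new = cls(lin.in_features, lin.out_features, bias=lin.bias is not None)
        new.load_state_dict(lin.state_dict())
        return new


def swap_linears(module: nn.Module, quantize_embedders: bool = False) -> None:
    # Mirrors the gating in the diff: embedder-like children are only swapped
    # when explicitly requested, and anything named "final_layer" is skipped.
    embedder_names = {"img_in", "txt_in", "time_in", "vector_in", "guidance_in"}
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, QuantLinearStub):
            if name == "final_layer":
                continue
            if name in embedder_names and not quantize_embedders:
                continue
            setattr(module, name, QuantLinearStub.from_linear(child))
        else:
            swap_linears(child, quantize_embedders)
```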
modules/conditioner.py CHANGED

```diff
@@ -56,6 +56,16 @@ class HFEmbedder(nn.Module):
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

+        auto_quant_config = (
+            auto_quantization_config(quantization_dtype) if quantization_dtype else None
+        )
+
+        # BNB will move to cuda:0 by default if not specified
+        if isinstance(auto_quant_config, BitsAndBytesConfig):
+            hf_kwargs["device_map"] = {"": self.device.index}
+        if auto_quant_config is not None:
+            hf_kwargs["quantization_config"] = auto_quant_config
+
         if self.is_clip:
             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                 version, max_length=max_length
@@ -64,11 +74,6 @@ class HFEmbedder(nn.Module):
             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                 version,
                 **hf_kwargs,
-                quantization_config=(
-                    auto_quantization_config(quantization_dtype)
-                    if quantization_dtype
-                    else None
-                ),
             )

         else:
@@ -78,11 +83,6 @@ class HFEmbedder(nn.Module):
             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                 version,
                 **hf_kwargs,
-                quantization_config=(
-                    auto_quantization_config(quantization_dtype)
-                    if quantization_dtype
-                    else None
-                ),
             )

     def offload(self):
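```

The new `device_map` handling matters because, as the comment in the diff notes, bitsandbytes-quantized models land on `cuda:0` unless told otherwise. A minimal sketch of loading a T5 encoder pinned to a chosen GPU is below; the model id and device index are placeholders, not values from the repo.

```python
import torch
from transformers import BitsAndBytesConfig, T5EncoderModel

device_index = 1  # illustrative: the GPU the rest of the pipeline lives on

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Pin the quantized encoder to a specific GPU; without an explicit device_map,
# bitsandbytes-quantized models default to cuda:0.
t5 = T5EncoderModel.from_pretrained(
    "some-org/t5-encoder",  # hypothetical model id
    quantization_config=quant_config,
    device_map={"": device_index},
)
```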
modules/flux_model.py CHANGED

```diff
@@ -1,7 +1,11 @@
 from collections import namedtuple
 import os
+from typing import TYPE_CHECKING
 import torch

+if TYPE_CHECKING:
+    from util import ModelSpec
+
 DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
@@ -111,11 +115,39 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10


 class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int):
+    def __init__(
+        self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False
+    ):
+        from float8_quantize import F8Linear
+
         super().__init__()
-        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.in_layer = (
+            nn.Linear(in_dim, hidden_dim, bias=True)
+            if not prequantized
+            else (
+                F8Linear(
+                    in_features=in_dim,
+                    out_features=hidden_dim,
+                    bias=True,
+                )
+                if quantized
+                else nn.Linear(in_dim, hidden_dim, bias=True)
+            )
+        )
         self.silu = nn.SiLU()
-        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+        self.out_layer = (
+            nn.Linear(hidden_dim, hidden_dim, bias=True)
+            if not prequantized
+            else (
+                F8Linear(
+                    in_features=hidden_dim,
+                    out_features=hidden_dim,
+                    bias=True,
+                )
+                if quantized
+                else nn.Linear(hidden_dim, hidden_dim, bias=True)
+            )
+        )

     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
@@ -143,14 +175,38 @@ class QKNorm(torch.nn.Module):


 class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        prequantized: bool = False,
+    ):
         super().__init__()
+        from float8_quantize import F8Linear
+
         self.num_heads = num_heads
         head_dim = dim // num_heads

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.qkv = (
+            nn.Linear(dim, dim * 3, bias=qkv_bias)
+            if not prequantized
+            else F8Linear(
+                in_features=dim,
+                out_features=dim * 3,
+                bias=qkv_bias,
+            )
+        )
         self.norm = QKNorm(head_dim)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = (
+            nn.Linear(dim, dim)
+            if not prequantized
+            else F8Linear(
+                in_features=dim,
+                out_features=dim,
+                bias=True,
+            )
+        )
         self.K = 3
         self.H = self.num_heads
         self.KH = self.K * self.H
@@ -173,11 +229,21 @@ ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])


 class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool):
+    def __init__(self, dim: int, double: bool, quantized_modulation: bool = False):
         super().__init__()
+        from float8_quantize import F8Linear
+
         self.is_double = double
         self.multiplier = 6 if double else 3
-        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+        self.lin = (
+            nn.Linear(dim, self.multiplier * dim, bias=True)
+            if not quantized_modulation
+            else F8Linear(
+                in_features=dim,
+                out_features=self.multiplier * dim,
+                bias=True,
+            )
+        )
         self.act = nn.SiLU()

     def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
@@ -197,37 +263,83 @@ class DoubleStreamBlock(nn.Module):
         mlp_ratio: float,
         qkv_bias: bool = False,
         dtype: torch.dtype = torch.float16,
+        quantized_modulation: bool = False,
+        prequantized: bool = False,
     ):
         super().__init__()
+        from float8_quantize import F8Linear
+
         self.dtype = dtype

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
         self.hidden_size = hidden_size
-        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_mod = Modulation(
+            hidden_size, double=True, quantized_modulation=quantized_modulation
+        )
         self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.img_attn = SelfAttention(
-            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+            dim=hidden_size,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            prequantized=prequantized,
         )

         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.img_mlp = nn.Sequential(
-            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            (
+                nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
+                if not prequantized
+                else F8Linear(
+                    in_features=hidden_size,
+                    out_features=mlp_hidden_dim,
+                    bias=True,
+                )
+            ),
             nn.GELU(approximate="tanh"),
-            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+            (
+                nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
+                if not prequantized
+                else F8Linear(
+                    in_features=mlp_hidden_dim,
+                    out_features=hidden_size,
+                    bias=True,
+                )
+            ),
         )

-        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_mod = Modulation(
+            hidden_size, double=True, quantized_modulation=quantized_modulation
+        )
         self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.txt_attn = SelfAttention(
-            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+            dim=hidden_size,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            prequantized=prequantized,
         )

         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.txt_mlp = nn.Sequential(
-            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            (
+                nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
+                if not prequantized
+                else F8Linear(
+                    in_features=hidden_size,
+                    out_features=mlp_hidden_dim,
+                    bias=True,
+                )
+            ),
             nn.GELU(approximate="tanh"),
-            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+            (
+                nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
+                if not prequantized
+                else F8Linear(
+                    in_features=mlp_hidden_dim,
+                    out_features=hidden_size,
+                    bias=True,
+                )
+            ),
         )
         self.K = 3
         self.H = self.num_heads
@@ -296,8 +408,12 @@ class SingleStreamBlock(nn.Module):
         mlp_ratio: float = 4.0,
         qk_scale: float | None = None,
         dtype: torch.dtype = torch.float16,
+        quantized_modulation: bool = False,
+        prequantized: bool = False,
     ):
         super().__init__()
+        from float8_quantize import F8Linear
+
         self.dtype = dtype
         self.hidden_dim = hidden_size
         self.num_heads = num_heads
@@ -306,9 +422,25 @@ class SingleStreamBlock(nn.Module):

         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
         # qkv and mlp_in
-        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        self.linear1 = (
+            nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+            if not prequantized
+            else F8Linear(
+                in_features=hidden_size,
+                out_features=hidden_size * 3 + self.mlp_hidden_dim,
+                bias=True,
+            )
+        )
         # proj and mlp_out
-        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+        self.linear2 = (
+            nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+            if not prequantized
+            else F8Linear(
+                in_features=hidden_size + self.mlp_hidden_dim,
+                out_features=hidden_size,
+                bias=True,
+            )
+        )

         self.norm = QKNorm(head_dim)

@@ -316,7 +448,11 @@ class SingleStreamBlock(nn.Module):
         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

         self.mlp_act = nn.GELU(approximate="tanh")
-        self.modulation = Modulation(hidden_size, double=False)
+        self.modulation = Modulation(
+            hidden_size,
+            double=False,
+            quantized_modulation=quantized_modulation and prequantized,
+        )

         self.K = 3
         self.H = self.num_heads
@@ -364,50 +500,96 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """

-    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
+    def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16):
         super().__init__()

         self.dtype = dtype
-        self.params = params
-        self.in_channels = params.in_channels
+        self.params = config.params
+        self.in_channels = config.params.in_channels
         self.out_channels = self.in_channels
-        if params.hidden_size % params.num_heads != 0:
+        prequantized_flow = config.prequantized_flow
+        quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
+        quantized_modulation = config.quantize_modulation and prequantized_flow
+        from float8_quantize import F8Linear
+
+        if config.params.hidden_size % config.params.num_heads != 0:
             raise ValueError(
-                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+                f"Hidden size {config.params.hidden_size} must be divisible by num_heads {config.params.num_heads}"
             )
-        pe_dim = params.hidden_size // params.num_heads
-        if sum(params.axes_dim) != pe_dim:
+        pe_dim = config.params.hidden_size // config.params.num_heads
+        if sum(config.params.axes_dim) != pe_dim:
             raise ValueError(
-                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+                f"Got {config.params.axes_dim} but expected positional dim {pe_dim}"
             )
-        self.hidden_size = params.hidden_size
-        self.num_heads = params.num_heads
+        self.hidden_size = config.params.hidden_size
+        self.num_heads = config.params.num_heads
         self.pe_embedder = EmbedND(
             dim=pe_dim,
-            theta=params.theta,
-            axes_dim=params.axes_dim,
+            theta=config.params.theta,
+            axes_dim=config.params.axes_dim,
             dtype=self.dtype,
         )
-        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
-        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.img_in = (
+            nn.Linear(self.in_channels, self.hidden_size, bias=True)
+            if not prequantized_flow
+            else (
+                F8Linear(
+                    in_features=self.in_channels,
+                    out_features=self.hidden_size,
+                    bias=True,
+                )
+                if quantized_embedders
+                else nn.Linear(self.in_channels, self.hidden_size, bias=True)
+            )
+        )
+        self.time_in = MLPEmbedder(
+            in_dim=256,
+            hidden_dim=self.hidden_size,
+            prequantized=prequantized_flow,
+            quantized=quantized_embedders,
+        )
+        self.vector_in = MLPEmbedder(
+            config.params.vec_in_dim,
+            self.hidden_size,
+            prequantized=prequantized_flow,
+            quantized=quantized_embedders,
+        )
         self.guidance_in = (
-            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
-            if params.guidance_embed
+            MLPEmbedder(
+                in_dim=256,
+                hidden_dim=self.hidden_size,
+                prequantized=prequantized_flow,
+                quantized=quantized_embedders,
+            )
+            if config.params.guidance_embed
             else nn.Identity()
         )
-        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+        self.txt_in = (
+            nn.Linear(config.params.context_in_dim, self.hidden_size)
+            if not quantized_embedders
+            else (
+                F8Linear(
+                    in_features=config.params.context_in_dim,
+                    out_features=self.hidden_size,
+                    bias=True,
+                )
+                if quantized_embedders
+                else nn.Linear(config.params.context_in_dim, self.hidden_size)
+            )
+        )

         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
-                    mlp_ratio=params.mlp_ratio,
-                    qkv_bias=params.qkv_bias,
+                    mlp_ratio=config.params.mlp_ratio,
+                    qkv_bias=config.params.qkv_bias,
                     dtype=self.dtype,
+                    quantized_modulation=quantized_modulation,
+                    prequantized=prequantized_flow,
                 )
-                for _ in range(params.depth)
+                for _ in range(config.params.depth)
             ]
         )

@@ -416,10 +598,12 @@ class Flux(nn.Module):
                 SingleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
-                    mlp_ratio=params.mlp_ratio,
+                    mlp_ratio=config.params.mlp_ratio,
                     dtype=self.dtype,
+                    quantized_modulation=quantized_modulation,
+                    prequantized=prequantized_flow,
                 )
-                for _ in range(params.depth_single_blocks)
+                for _ in range(config.params.depth_single_blocks)
             ]
         )

@@ -472,13 +656,17 @@ class Flux(nn.Module):
         return img

     @classmethod
-    def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
+    def from_pretrained(
+        cls: "Flux", path: str, dtype: torch.dtype = torch.float16
+    ) -> "Flux":
        from util import load_config_from_path
        from safetensors.torch import load_file

        config = load_config_from_path(path)
        with torch.device("meta"):
-            klass = cls(params=config.params, dtype=dtype).type(dtype)
+            klass = cls(config=config, dtype=dtype)
+            if not config.prequantized_flow:
+                klass.type(dtype)

        ckpt = load_file(config.ckpt_path, device="cpu")
        klass.load_state_dict(ckpt, assign=True)
```
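`Flux.from_pretrained` now builds the model under `torch.device("meta")` and materializes the weights with `load_state_dict(..., assign=True)`, only calling `.type(dtype)` when the checkpoint is not prequantized (so FP8 tensors keep their dtype). A small self-contained illustration of that loading pattern with a toy module (not the repo's Flux) follows.

```python
import torch
from torch import nn


# Build the module without allocating real storage, then let load_state_dict
# materialize tensors straight from a checkpoint state dict via assign=True.
def build() -> nn.Module:
    return nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))


checkpoint = build().half().state_dict()  # stands in for load_file(ckpt_path)

with torch.device("meta"):
    model = build()

model.load_state_dict(checkpoint, assign=True)
print(next(model.parameters()).device, next(model.parameters()).dtype)

# Mirrors the diff's behaviour: cast only when the checkpoint is not
# prequantized, so quantized weights would keep their stored dtype.
prequantized = False  # illustrative flag, mirrors config.prequantized_flow
if not prequantized:
    model = model.type(torch.float16)
```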
modules/flux_model_f8.py DELETED

The entire module (491 lines) was removed. Its former contents:

```python
from collections import namedtuple
import os
import torch

DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")
import math

from torch import Tensor, nn
from pydantic import BaseModel
from torch.nn import functional as F
from float8_quantize import F8Linear

try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = nn.Linear


class FluxParams(BaseModel):
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


# attention is always same shape each time it's called per H*W, so compile with fullgraph
# @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
    x = x.reshape(*x.shape[:-2], -1)
    return x


# @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)


class EmbedND(nn.Module):
    def __init__(
        self,
        dim: int,
        theta: int,
        axes_dim: list[int],
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
        self.dtype = dtype

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [
                rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
                for i in range(n_axes)
            ],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = F8Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = F8Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q, k


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = F8Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = F8Linear(dim, dim)
        self.K = 3
        self.H = self.num_heads
        self.KH = self.K * self.H

    def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        B, L, D = x.shape
        q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
        return q, k, v

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = self.rearrange_for_norm(qkv)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = F8Linear(dim, self.multiplier * dim, bias=True)
        self.act = nn.SiLU()

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.dtype = dtype

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            F8Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            F8Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            F8Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            F8Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        self.K = 3
        self.H = self.num_heads
        self.KH = self.K * self.H

    def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        B, L, D = x.shape
        q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
        return q, k, v

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
        # calculate the img bloks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        ).clamp(min=-384 * 2, max=384 * 2)

        # calculate the txt bloks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        ).clamp(min=-384 * 2, max=384 * 2)

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.dtype = dtype
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = F8Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = F8Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        self.K = 3
        self.H = self.num_heads
        self.KH = self.K * self.H

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod = self.modulation(vec)[0]
        pre_norm = self.pre_norm(x)
        x_mod = (1 + mod.scale) * pre_norm + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim],
            dim=-1,
        )
        B, L, D = qkv.shape
        q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe)
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
            min=-384 * 4, max=384 * 4
        )
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = CublasLinear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
        super().__init__()

        self.dtype = dtype
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim,
            theta=params.theta,
            axes_dim=params.axes_dim,
            dtype=self.dtype,
        )
        self.img_in = F8Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if params.guidance_embed
            else nn.Identity()
        )
        self.txt_in = F8Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=self.dtype,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    dtype=self.dtype,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))

        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(
                timestep_embedding(guidance, 256).type(self.dtype)
            )
        vec = vec + self.vector_in(y)

        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        # double stream blocks
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)

        # single stream blocks
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)

        img = img[:, txt.shape[1] :, ...]
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
        from util import load_config_from_path
        from safetensors.torch import load_file

        config = load_config_from_path(path)
        with torch.device("meta"):
            klass = cls(params=config.params, dtype=dtype).type(dtype)

        ckpt = load_file(config.ckpt_path, device="cpu")
        klass.load_state_dict(ckpt, assign=True)
        return klass.to("cpu")
```
util.py CHANGED

```diff
@@ -6,7 +6,6 @@ import torch
 from modules.autoencoder import AutoEncoder, AutoEncoderParams
 from modules.conditioner import HFEmbedder
 from modules.flux_model import Flux, FluxParams
-from modules.flux_model_f8 import Flux as FluxF8
 from safetensors.torch import load_file as load_sft

 try:
@@ -68,7 +67,7 @@ class ModelSpec(BaseModel):
     # Improved precision via not quanitzing the modulation linear layers
     quantize_modulation: bool = True
     # Improved precision via not quanitzing the flow embedder layers
-    quantize_flow_embedder_layers: bool = True
+    quantize_flow_embedder_layers: bool = False

     model_config: ConfigDict = {
         "arbitrary_types_allowed": True,
@@ -230,16 +229,14 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
         )


-def load_flow_model(config: ModelSpec) -> Flux | FluxF8:
+def load_flow_model(config: ModelSpec) -> Flux:
     ckpt_path = config.ckpt_path
     FluxClass = Flux
-    if config.prequantized_flow:
-        FluxClass = FluxF8

     with torch.device("meta"):
-        model = FluxClass(config.params, dtype=into_dtype(config.flow_dtype)).type(
-            into_dtype(config.flow_dtype)
-        )
+        model = FluxClass(config, dtype=into_dtype(config.flow_dtype))
+        if not config.prequantized_flow:
+            model.type(into_dtype(config.flow_dtype))

     if ckpt_path is not None:
         # load_sft doesn't support torch.device
@@ -290,7 +287,7 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:


 class LoadedModels(BaseModel):
-    flow: Flux | FluxF8
+    flow: Flux
     ae: AutoEncoder
     clip: HFEmbedder
     t5: HFEmbedder
```
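With `quantize_flow_embedder_layers` now defaulting to `False`, the quantization behaviour is driven entirely by `ModelSpec` fields at load time. The stand-in below only mirrors the three flags touched by this commit (the real `ModelSpec` in util.py has many more required fields); it is an illustration, not the repo's class.

```python
from pydantic import BaseModel


# Minimal stand-in mirroring the ModelSpec flags touched by this commit.
class QuantFlags(BaseModel):
    prequantized_flow: bool = False
    quantize_modulation: bool = True
    # New default in this commit: flow embedder layers are no longer quantized.
    quantize_flow_embedder_layers: bool = False


flags = QuantFlags()
print(flags.model_dump())
```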