Safetensors: aredden committed af20799 (unverified) · 2 parent(s): b84f35e 49f2076

Merge pull request #3 from aredden/improved_precision

Files changed (8)
  1. README.md +33 -28
  2. float8_quantize.py +85 -18
  3. flux_pipeline.py +95 -59
  4. main.py +18 -0
  5. modules/conditioner.py +10 -10
  6. modules/flux_model.py +234 -51
  7. modules/flux_model_f8.py +0 -491
  8. util.py +20 -8
README.md CHANGED
@@ -41,6 +41,19 @@ Note:
41
  - [Examples](#examples)
42
  - [License](#license)
43
 
 
44
  ## Installation
45
 
46
  This repo _requires_ at least PyTorch with CUDA 12.4 and an Ada-architecture GPU with fp8 support; otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
@@ -106,30 +119,8 @@ python main.py --config-path <path_to_config> --port <port_number> --host <host_
106
  - `--no-offload-ae`: Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed (default: True [implies it will offload, setting this flag sets it to False]).
107
  - `--no-offload-text-enc`: Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed (default: True [implies it will offload, setting this flag sets it to False]).
108
  - `--prequantized-flow`: Load the flow model from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: False).
109
-
110
- ## Examples
111
-
112
- ### Running the Server
113
-
114
- ```bash
115
- python main.py --config-path configs/config-dev-1-4090.json --port 8088 --host 0.0.0.0
116
- ```
117
-
118
- Or if you need more granular control over the all of the settings, you can run the server with something like this:
119
-
120
- ```bash
121
- python main.py --port 8088 --host 0.0.0.0 \
122
- --flow-model-path /path/to/your/flux1-dev.sft \
123
- --text-enc-path /path/to/your/t5-v1_1-xxl-encoder-bf16 \
124
- --autoencoder-path /path/to/your/ae.sft \
125
- --model-version flux-dev \
126
- --flux-device cuda:0 \
127
- --text-enc-device cuda:0 \
128
- --autoencoder-device cuda:0 \
129
- --compile \
130
- --quant-text-enc qfloat8 \
131
- --quant-ae
132
- ```
133
 
134
  ## Configuration
135
 
@@ -185,7 +176,10 @@ Example configuration file for a single 4090 (`configs/config-dev-offload-1-4090
185
  "compile_blocks": true, // compile the single-blocks and double-blocks
186
  "offload_text_encoder": true, // offload the text encoder to cpu when not in use
187
  "offload_vae": true, // offload the autoencoder to cpu when not in use
188
- "offload_flow": false // offload the flow transformer to cpu when not in use
 
189
  }
190
  ```
191
 
@@ -232,6 +226,17 @@ Other things to change can be the
232
  - `"ae_device": "cuda:0",`
233
  Device for the autoencoder (default: `cuda:0`). Set this to a different device, e.g. `"cuda:1"`, if you have multiple GPUs so that you can disable autoencoder offloading; it does not need to be the same as flux_device or text_enc_device.
234
 
 
235
  ## API Endpoints
236
 
237
  ### Generate Image
@@ -256,10 +261,10 @@ Other things to change can be the
256
  ### Running the Server
257
 
258
  ```bash
259
- python main.py --config-path configs/config-dev-offload-1-4090.json --port 8088 --host 0.0.0.0
260
  ```
261
 
262
- OR, if you need more granular control over the server, you can run the server with something like this:
263
 
264
  ```bash
265
  python main.py --port 8088 --host 0.0.0.0 \
@@ -275,7 +280,7 @@ python main.py --port 8088 --host 0.0.0.0 \
275
  --quant-ae
276
  ```
277
 
278
- ### Generating an Image
279
 
280
  Send a POST request to `http://<host>:<port>/generate` with the following JSON body:
281
 
 
41
  - [Examples](#examples)
42
  - [License](#license)
43
 
44
+ ### Updates 08/24/24
45
+
46
+ - Add config options for levels of quantization for the flow transformer:
47
+ - `quantize_modulation`: Quantize the modulation layers in the flow model. If false, adds ~2GB vram usage for moderate precision improvements `(default: true)`
48
+ - `quantize_flow_embedder_layers`: Quantize the flow embedder layers in the flow model. If false, adds ~512MB vram usage, but precision improves considerably. `(default: false)`
49
+ - Override default config values when loading FluxPipeline, e.g. `FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides)` (see the example below)
50
+
51
+ #### Fixes
52
+
53
+ - Fix a bug where loading the text encoder from HF with bitsandbytes (bnb) would error if the device was not set to cuda:0
54
+
55
+ **Note:** prequantized flow models only work with the same quantization levels that were used when they were created, e.g. if you create a prequantized flow model with `quantize_modulation` set to false, it will only work with `quantize_modulation` set to false; the same applies to `quantize_flow_embedder_layers`.
56
+
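For reference, here is a minimal, hedged sketch of the override mechanism mentioned above; the config path and the override values are illustrative and not part of this commit. Any keyword that matches a field on the loaded config is applied before the pipeline is built.

```python
# Minimal sketch, assuming the example config shipped with the repo is available.
from flux_pipeline import FluxPipeline

pipeline = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev-1-4090.json",     # example config referenced in this README
    quantize_modulation=False,            # trade ~2GB of VRAM for better precision
    quantize_flow_embedder_layers=False,  # leave the embedder layers unquantized (default)
)
```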
57
  ## Installation
58
 
59
  This repo _requires_ at least PyTorch with CUDA 12.4 and an Ada-architecture GPU with fp8 support; otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
 
119
  - `--no-offload-ae`: Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed (default: True [implies it will offload, setting this flag sets it to False]).
120
  - `--no-offload-text-enc`: Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed (default: True [implies it will offload, setting this flag sets it to False]).
121
  - `--prequantized-flow`: Load the flow model from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: False).
122
+ - `--no-quantize-flow-modulation`: Disable quantization of the modulation layers in the flow transformer, which improves precision _moderately_ but adds ~2GB vram usage.
123
+ - `--quantize-flow-embedder-layers`: Quantize the flow embedder layers in the flow transformer, which reduces precision _considerably_ but saves ~512MB of VRAM.
 
124
 
125
  ## Configuration
126
 
 
176
  "compile_blocks": true, // compile the single-blocks and double-blocks
177
  "offload_text_encoder": true, // offload the text encoder to cpu when not in use
178
  "offload_vae": true, // offload the autoencoder to cpu when not in use
179
+ "offload_flow": false, // offload the flow transformer to cpu when not in use
180
+ "prequantized_flow": false, // load the flow transformer from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: false)
181
+ "quantize_modulation": true, // quantize the modulation layers in the flow transformer, which reduces precision moderately but saves ~2GB vram usage (default: true)
182
+ "quantize_flow_embedder_layers": false, // quantize the flow embedder layers in the flow transformer, if false, improves precision considerably at the cost of adding ~512MB vram usage (default: false)
183
  }
184
  ```
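As a hedged aside (not part of this commit), the JSON above is parsed into the repo's ModelSpec config, so the new fields can be inspected on the loaded object before building the pipeline; the config path below is one of the example configs referenced in this README.

```python
# Minimal sketch: load a config and check the new quantization-related fields.
from util import load_config_from_path

config = load_config_from_path("configs/config-dev-offload-1-4090.json")
print(config.prequantized_flow, config.quantize_modulation, config.quantize_flow_embedder_layers)
```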
185
 
 
226
  - `"ae_device": "cuda:0",`
227
  Device for the autoencoder (default: `cuda:0`). Set this to a different device, e.g. `"cuda:1"`, if you have multiple GPUs so that you can disable autoencoder offloading; it does not need to be the same as flux_device or text_enc_device.
228
 
229
+ - `"prequantized_flow": false,`
230
+ load the flow transformer from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: false)
231
+
232
+ - Note: MUST be a prequantized checkpoint created with the same quantization settings as the current config, and must have been quantized using this repo.
233
+
234
+ - `"quantize_modulation": true,`
235
+ quantize the modulation layers in the flow transformer (default: true); setting this to false improves precision moderately at the cost of ~2GB more VRAM
236
+
237
+ - `"quantize_flow_embedder_layers": false,`
238
+ quantize the flow embedder layers in the flow transformer (default: false); leaving this false improves precision considerably at the cost of ~512MB more VRAM
239
+
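To make the interaction with `prequantized_flow` concrete, here is a hedged sketch (not a script from this repo) of how such a checkpoint could be produced; the output filename is an assumption, and the function call mirrors the one made in flux_pipeline.py.

```python
# Sketch only: per main.py's help text, a prequantized checkpoint is the
# float8-quantized flow model's state_dict saved as a safetensors file.
import torch
from safetensors.torch import save_file

from float8_quantize import quantize_flow_transformer_and_dispatch_float8
from modules.flux_model import Flux

flow = Flux.from_pretrained("configs/config-dev-1-4090.json", dtype=torch.float16)
flow = quantize_flow_transformer_and_dispatch_float8(
    flow,
    torch.device("cuda:0"),
    swap_linears_with_cublaslinear=False,  # keep nn.Linear/F8Linear weights in the saved checkpoint
    quantize_modulation=True,              # must match the config used to load it later
    quantize_flow_embedder_layers=False,   # must match the config used to load it later
)
save_file(flow.state_dict(), "flux1-dev-fp8.safetensors")  # hypothetical filename
```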
240
  ## API Endpoints
241
 
242
  ### Generate Image
 
261
  ### Running the Server
262
 
263
  ```bash
264
+ python main.py --config-path configs/config-dev-1-4090.json --port 8088 --host 0.0.0.0
265
  ```
266
 
267
+ Or, if you need more granular control over all of the settings, you can run the server with something like this:
268
 
269
  ```bash
270
  python main.py --port 8088 --host 0.0.0.0 \
 
280
  --quant-ae
281
  ```
282
 
283
+ ### Generating an Image on a Client
284
 
285
  Send a POST request to `http://<host>:<port>/generate` with the following JSON body:
286
 
float8_quantize.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
  from torchao.float8.float8_utils import (
@@ -10,7 +11,8 @@ import math
10
  from torch.compiler import is_compiling
11
  from torch import __version__
12
  from torch.version import cuda
13
- from typing import TypeVar
 
14
 
15
  IS_TORCH_2_4 = __version__ < (2, 4, 9)
16
  LT_TORCH_2_4 = __version__ < (2, 4)
@@ -42,7 +44,7 @@ class F8Linear(nn.Module):
42
  float8_dtype=torch.float8_e4m3fn,
43
  float_weight: torch.Tensor = None,
44
  float_bias: torch.Tensor = None,
45
- num_scale_trials: int = 24,
46
  input_float8_dtype=torch.float8_e5m2,
47
  ) -> None:
48
  super().__init__()
@@ -183,6 +185,11 @@ class F8Linear(nn.Module):
183
  1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
184
  )
185
 
 
186
  def quantize_input(self, x: torch.Tensor):
187
  if self.input_scale_initialized:
188
  return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
@@ -279,10 +286,12 @@ class F8Linear(nn.Module):
279
  return f8_lin
280
 
281
 
 
282
  def recursive_swap_linears(
283
  model: nn.Module,
284
  float8_dtype=torch.float8_e4m3fn,
285
  input_float8_dtype=torch.float8_e5m2,
 
286
  ) -> None:
287
  """
288
  Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
@@ -300,6 +309,8 @@ def recursive_swap_linears(
300
  all linear layers in the model will be using 8-bit quantization.
301
  """
302
  for name, child in model.named_children():
 
 
303
  if isinstance(child, nn.Linear) and not isinstance(
304
  child, (F8Linear, CublasLinear)
305
  ):
@@ -315,7 +326,35 @@ def recursive_swap_linears(
315
  )
316
  del child
317
  else:
318
- recursive_swap_linears(child)
 
319
 
320
 
321
  @torch.inference_mode()
@@ -325,6 +364,10 @@ def quantize_flow_transformer_and_dispatch_float8(
325
  float8_dtype=torch.float8_e4m3fn,
326
  input_float8_dtype=torch.float8_e5m2,
327
  offload_flow=False,
 
328
  ) -> nn.Module:
329
  """
330
  Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
@@ -334,19 +377,36 @@ def quantize_flow_transformer_and_dispatch_float8(
334
  Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.
335
 
336
  After dispatching, if offload_flow is True, offloads the model to cpu.
 
337
  """
338
  for module in flow_model.double_blocks:
339
  module.to(device)
340
  module.eval()
341
  recursive_swap_linears(
342
- module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
 
 
 
343
  )
344
  torch.cuda.empty_cache()
345
  for module in flow_model.single_blocks:
346
  module.to(device)
347
  module.eval()
348
  recursive_swap_linears(
349
- module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
 
 
 
350
  )
351
  torch.cuda.empty_cache()
352
  to_gpu_extras = [
@@ -367,23 +427,30 @@ def quantize_flow_transformer_and_dispatch_float8(
367
  if isinstance(m_extra, nn.Linear) and not isinstance(
368
  m_extra, (F8Linear, CublasLinear)
369
  ):
370
- setattr(
371
- flow_model,
372
- module,
373
- F8Linear.from_linear(
 
374
  m_extra,
375
  float8_dtype=float8_dtype,
376
  input_float8_dtype=input_float8_dtype,
377
- ),
378
- )
379
- del m_extra
380
- elif module != "final_layer":
381
- recursive_swap_linears(
382
- m_extra,
383
- float8_dtype=float8_dtype,
384
- input_float8_dtype=input_float8_dtype,
385
- )
386
  torch.cuda.empty_cache()
 
387
  if offload_flow:
388
  flow_model.to("cpu")
389
  torch.cuda.empty_cache()
 
1
+ from loguru import logger
2
  import torch
3
  import torch.nn as nn
4
  from torchao.float8.float8_utils import (
 
11
  from torch.compiler import is_compiling
12
  from torch import __version__
13
  from torch.version import cuda
14
+
15
+ from modules.flux_model import Modulation
16
 
17
  IS_TORCH_2_4 = __version__ < (2, 4, 9)
18
  LT_TORCH_2_4 = __version__ < (2, 4)
 
44
  float8_dtype=torch.float8_e4m3fn,
45
  float_weight: torch.Tensor = None,
46
  float_bias: torch.Tensor = None,
47
+ num_scale_trials: int = 12,
48
  input_float8_dtype=torch.float8_e5m2,
49
  ) -> None:
50
  super().__init__()
 
185
  1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
186
  )
187
 
188
+ def set_weight_tensor(self, tensor: torch.Tensor):
189
+ self.weight.data = tensor
190
+ self.weight_initialized = False
191
+ self.quantize_weight()
192
+
193
  def quantize_input(self, x: torch.Tensor):
194
  if self.input_scale_initialized:
195
  return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
 
286
  return f8_lin
287
 
288
 
289
+ @torch.inference_mode()
290
  def recursive_swap_linears(
291
  model: nn.Module,
292
  float8_dtype=torch.float8_e4m3fn,
293
  input_float8_dtype=torch.float8_e5m2,
294
+ quantize_modulation: bool = True,
295
  ) -> None:
296
  """
297
  Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
 
309
  all linear layers in the model will be using 8-bit quantization.
310
  """
311
  for name, child in model.named_children():
312
+ if isinstance(child, Modulation) and not quantize_modulation:
313
+ continue
314
  if isinstance(child, nn.Linear) and not isinstance(
315
  child, (F8Linear, CublasLinear)
316
  ):
 
326
  )
327
  del child
328
  else:
329
+ recursive_swap_linears(
330
+ child,
331
+ float8_dtype=float8_dtype,
332
+ input_float8_dtype=input_float8_dtype,
333
+ quantize_modulation=quantize_modulation,
334
+ )
335
+
336
+
337
+ @torch.inference_mode()
338
+ def swap_to_cublaslinear(model: nn.Module):
339
+ if not isinstance(CublasLinear, torch.nn.Module):
340
+ return
341
+ for name, child in model.named_children():
342
+ if isinstance(child, nn.Linear) and not isinstance(
343
+ child, (F8Linear, CublasLinear)
344
+ ):
345
+ cublas_lin = CublasLinear(
346
+ child.in_features,
347
+ child.out_features,
348
+ bias=child.bias is not None,
349
+ dtype=child.weight.dtype,
350
+ device=child.weight.device,
351
+ )
352
+ cublas_lin.weight.data = child.weight.clone().detach()
353
+ cublas_lin.bias.data = child.bias.clone().detach()
354
+ setattr(model, name, cublas_lin)
355
+ del child
356
+ else:
357
+ swap_to_cublaslinear(child)
358
 
359
 
360
  @torch.inference_mode()
 
364
  float8_dtype=torch.float8_e4m3fn,
365
  input_float8_dtype=torch.float8_e5m2,
366
  offload_flow=False,
367
+ swap_linears_with_cublaslinear=True,
368
+ flow_dtype=torch.float16,
369
+ quantize_modulation: bool = True,
370
+ quantize_flow_embedder_layers: bool = True,
371
  ) -> nn.Module:
372
  """
373
  Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
 
377
  Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.
378
 
379
  After dispatching, if offload_flow is True, offloads the model to cpu.
380
+
381
+ If swap_linears_with_cublaslinear is True and flow_dtype == torch.float16, swaps all linears with CublasLinear layers for a ~2x performance boost on consumer GPUs.
382
+ Otherwise will skip the cublaslinear swap.
383
+
384
+ For added extra precision, you can set quantize_flow_embedder_layers to False,
385
+ this helps maintain the output quality of the flow transformer better than fully quantizing,
386
+ at the expense of ~512MB more VRAM usage.
387
+
388
+ For added extra precision, you can set quantize_modulation to False,
389
+ this helps maintain the output quality of the flow transformer better than fully quantizing,
390
+ at the expense of ~2GB more VRAM usage, but it has a much higher impact on image quality than the embedder layers.
391
  """
392
  for module in flow_model.double_blocks:
393
  module.to(device)
394
  module.eval()
395
  recursive_swap_linears(
396
+ module,
397
+ float8_dtype=float8_dtype,
398
+ input_float8_dtype=input_float8_dtype,
399
+ quantize_modulation=quantize_modulation,
400
  )
401
  torch.cuda.empty_cache()
402
  for module in flow_model.single_blocks:
403
  module.to(device)
404
  module.eval()
405
  recursive_swap_linears(
406
+ module,
407
+ float8_dtype=float8_dtype,
408
+ input_float8_dtype=input_float8_dtype,
409
+ quantize_modulation=quantize_modulation,
410
  )
411
  torch.cuda.empty_cache()
412
  to_gpu_extras = [
 
427
  if isinstance(m_extra, nn.Linear) and not isinstance(
428
  m_extra, (F8Linear, CublasLinear)
429
  ):
430
+ if quantize_flow_embedder_layers:
431
+ setattr(
432
+ flow_model,
433
+ module,
434
+ F8Linear.from_linear(
435
+ m_extra,
436
+ float8_dtype=float8_dtype,
437
+ input_float8_dtype=input_float8_dtype,
438
+ ),
439
+ )
440
+ del m_extra
441
+ elif module != "final_layer":
442
+ if quantize_flow_embedder_layers:
443
+ recursive_swap_linears(
444
  m_extra,
445
  float8_dtype=float8_dtype,
446
  input_float8_dtype=input_float8_dtype,
447
+ quantize_modulation=quantize_modulation,
448
+ )
 
449
  torch.cuda.empty_cache()
450
+ if swap_linears_with_cublaslinear and flow_dtype == torch.float16:
451
+ swap_to_cublaslinear(flow_model)
452
+ elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
453
+ logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
454
  if offload_flow:
455
  flow_model.to("cpu")
456
  torch.cuda.empty_cache()
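As a usage sketch of the helpers changed above (toy module and sizes are illustrative; per the README this requires a GPU with fp8 support):

```python
# Hedged sketch: swapping linears for F8Linear on a toy module. The dtypes
# mirror the defaults in this file; nothing here is part of the commit itself.
import torch
import torch.nn as nn

from float8_quantize import F8Linear, recursive_swap_linears


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)


lin = nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
f8 = F8Linear.from_linear(
    lin,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
)  # weight stored as float8_e4m3fn; the input scale is tuned over the first forward passes

toy = Toy().cuda().bfloat16()
# Swaps every nn.Linear in place; with quantize_modulation=False, Modulation
# submodules (none in this toy model) are skipped.
recursive_swap_linears(toy, quantize_modulation=False)
```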
flux_pipeline.py CHANGED
@@ -31,6 +31,7 @@ from torchvision.transforms import functional as TF
31
  from tqdm import tqdm
32
  from util import (
33
  ModelSpec,
 
34
  into_device,
35
  into_dtype,
36
  load_config_from_path,
@@ -80,29 +81,17 @@ class FluxPipeline:
80
  This class is responsible for preparing input tensors for the Flux model, generating
81
  timesteps and noise, and handling device management for model offloading.
82
  """
 
83
  self.debug = debug
84
  self.name = name
85
- self.device_flux = (
86
- flux_device
87
- if isinstance(flux_device, torch.device)
88
- else torch.device(flux_device)
89
- )
90
- self.device_ae = (
91
- ae_device
92
- if isinstance(ae_device, torch.device)
93
- else torch.device(ae_device)
94
- )
95
- self.device_clip = (
96
- clip_device
97
- if isinstance(clip_device, torch.device)
98
- else torch.device(clip_device)
99
- )
100
- self.device_t5 = (
101
- t5_device
102
- if isinstance(t5_device, torch.device)
103
- else torch.device(t5_device)
104
- )
105
- self.dtype = dtype
106
  self.offload = offload
107
  self.clip: "HFEmbedder" = clip
108
  self.t5: "HFEmbedder" = t5
@@ -116,6 +105,8 @@ class FluxPipeline:
116
  self.offload_text_encoder = config.offload_text_encoder
117
  self.offload_vae = config.offload_vae
118
  self.offload_flow = config.offload_flow
 
 
119
  if not self.offload_flow:
120
  self.model.to(self.device_flux)
121
  if not self.offload_vae:
@@ -124,40 +115,16 @@ class FluxPipeline:
124
  self.clip.to(self.device_clip)
125
  self.t5.to(self.device_t5)
126
 
127
- if self.config.compile_blocks or self.config.compile_extras:
128
- if not self.config.prequantized_flow:
129
- logger.info("Running warmups for compile...")
130
- warmup_dict = dict(
131
- prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
132
- height=768,
133
- width=768,
134
- num_steps=25,
135
- guidance=3.5,
136
- seed=10,
137
- )
138
- self.generate(**warmup_dict)
139
- to_gpu_extras = [
140
- "vector_in",
141
- "img_in",
142
- "txt_in",
143
- "time_in",
144
- "guidance_in",
145
- "final_layer",
146
- "pe_embedder",
147
- ]
148
- if self.config.compile_blocks:
149
- for block in self.model.double_blocks:
150
- block.compile()
151
- for block in self.model.single_blocks:
152
- block.compile()
153
- if self.config.compile_extras:
154
- for extra in to_gpu_extras:
155
- getattr(self.model, extra).compile()
156
-
157
- def set_seed(self, seed: int | None = None) -> torch.Generator:
158
  if isinstance(seed, (int, float)):
159
  seed = int(abs(seed)) % MAX_RAND
160
- self.rng = torch.manual_seed(seed)
161
  elif isinstance(seed, str):
162
  try:
163
  seed = abs(int(seed)) % MAX_RAND
@@ -166,14 +133,71 @@ class FluxPipeline:
166
  f"Recieved string representation of seed, but was not able to convert to int: {seed}, using random seed"
167
  )
168
  seed = abs(self.rng.seed()) % MAX_RAND
 
169
  else:
170
  seed = abs(self.rng.seed()) % MAX_RAND
171
- torch.cuda.manual_seed_all(seed)
172
- np.random.seed(seed)
173
- random.seed(seed)
174
- cuda_generator = torch.Generator("cuda").manual_seed(seed)
 
 
175
  return cuda_generator, seed
176
 
 
177
  @torch.inference_mode()
178
  def prepare(
179
  self,
@@ -608,12 +632,18 @@ class FluxPipeline:
608
 
609
  @classmethod
610
  def load_pipeline_from_config_path(
611
- cls, path: str, flow_model_path: str = None, debug: bool = False
612
  ) -> "FluxPipeline":
613
  with torch.inference_mode():
614
  config = load_config_from_path(path)
615
  if flow_model_path:
616
  config.ckpt_path = flow_model_path
 
617
  return cls.load_pipeline_from_config(config, debug=debug)
618
 
619
  @classmethod
@@ -639,7 +669,13 @@ class FluxPipeline:
639
 
640
  if not config.prequantized_flow:
641
  flow_model = quantize_flow_transformer_and_dispatch_float8(
642
- flow_model, flux_device, offload_flow=config.offload_flow
 
643
  )
644
  else:
645
  flow_model.eval().requires_grad_(False)
 
31
  from tqdm import tqdm
32
  from util import (
33
  ModelSpec,
34
+ ModelVersion,
35
  into_device,
36
  into_dtype,
37
  load_config_from_path,
 
81
  This class is responsible for preparing input tensors for the Flux model, generating
82
  timesteps and noise, and handling device management for model offloading.
83
  """
84
+
85
+ if config is None:
86
+ raise ValueError("ModelSpec config is required!")
87
+
88
  self.debug = debug
89
  self.name = name
90
+ self.device_flux = into_device(flux_device)
91
+ self.device_ae = into_device(ae_device)
92
+ self.device_clip = into_device(clip_device)
93
+ self.device_t5 = into_device(t5_device)
94
+ self.dtype = into_dtype(dtype)
 
95
  self.offload = offload
96
  self.clip: "HFEmbedder" = clip
97
  self.t5: "HFEmbedder" = t5
 
105
  self.offload_text_encoder = config.offload_text_encoder
106
  self.offload_vae = config.offload_vae
107
  self.offload_flow = config.offload_flow
108
+ # If models are not offloaded, move them to the appropriate devices
109
+
110
  if not self.offload_flow:
111
  self.model.to(self.device_flux)
112
  if not self.offload_vae:
 
115
  self.clip.to(self.device_clip)
116
  self.t5.to(self.device_t5)
117
 
118
+ # compile the model if needed
119
+ if config.compile_blocks or config.compile_extras:
120
+ self.compile()
121
+
122
+ def set_seed(
123
+ self, seed: int | None = None, seed_globally: bool = False
124
+ ) -> torch.Generator:
 
125
  if isinstance(seed, (int, float)):
126
  seed = int(abs(seed)) % MAX_RAND
127
+ cuda_generator = torch.Generator("cuda").manual_seed(seed)
128
  elif isinstance(seed, str):
129
  try:
130
  seed = abs(int(seed)) % MAX_RAND
 
133
  f"Recieved string representation of seed, but was not able to convert to int: {seed}, using random seed"
134
  )
135
  seed = abs(self.rng.seed()) % MAX_RAND
136
+ cuda_generator = torch.Generator("cuda").manual_seed(seed)
137
  else:
138
  seed = abs(self.rng.seed()) % MAX_RAND
139
+ cuda_generator = torch.Generator("cuda").manual_seed(seed)
140
+
141
+ if seed_globally:
142
+ torch.cuda.manual_seed_all(seed)
143
+ np.random.seed(seed)
144
+ random.seed(seed)
145
  return cuda_generator, seed
146
 
147
+ @torch.inference_mode()
148
+ def compile(self):
149
+ """
150
+ Compiles the model and extras.
151
+
152
+ There are two cases:
153
+
154
+ - A) Checkpoint which already has float8 quantized weights and tuned input scales.
155
+ In which case, it will not run warmups since it assumes the input scales are already tuned.
156
+
157
+ - B) Checkpoint which has not been quantized, in which case it will be quantized
158
+ and the input scales will be tuned by running a warmup loop.
159
+ - If the model is flux-schnell, it will run 3 warmup loops since each loop is 4 steps.
160
+ - If the model is flux-dev, it will run 1 warmup loop for 12 steps.
161
+
162
+ """
163
+
164
+ # Run warmups if the checkpoint is not prequantized
165
+ if not self.config.prequantized_flow:
166
+ logger.info("Running warmups for compile...")
167
+ warmup_dict = dict(
168
+ prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
169
+ height=768,
170
+ width=768,
171
+ num_steps=12,
172
+ guidance=3.5,
173
+ seed=10,
174
+ )
175
+ if self.config.version == ModelVersion.flux_schnell:
176
+ warmup_dict["num_steps"] = 4
177
+ for _ in range(3):
178
+ self.generate(**warmup_dict)
179
+ else:
180
+ self.generate(**warmup_dict)
181
+
182
+ # Compile the model and extras
183
+ to_gpu_extras = [
184
+ "vector_in",
185
+ "img_in",
186
+ "txt_in",
187
+ "time_in",
188
+ "guidance_in",
189
+ "final_layer",
190
+ "pe_embedder",
191
+ ]
192
+ if self.config.compile_blocks:
193
+ for block in self.model.double_blocks:
194
+ block.compile()
195
+ for block in self.model.single_blocks:
196
+ block.compile()
197
+ if self.config.compile_extras:
198
+ for extra in to_gpu_extras:
199
+ getattr(self.model, extra).compile()
200
+
201
  @torch.inference_mode()
202
  def prepare(
203
  self,
 
632
 
633
  @classmethod
634
  def load_pipeline_from_config_path(
635
+ cls, path: str, flow_model_path: str = None, debug: bool = False, **kwargs
636
  ) -> "FluxPipeline":
637
  with torch.inference_mode():
638
  config = load_config_from_path(path)
639
  if flow_model_path:
640
  config.ckpt_path = flow_model_path
641
+ for k, v in kwargs.items():
642
+ if hasattr(config, k):
643
+ logger.info(
644
+ f"Overriding config {k}:{getattr(config, k)} with value {v}"
645
+ )
646
+ setattr(config, k, v)
647
  return cls.load_pipeline_from_config(config, debug=debug)
648
 
649
  @classmethod
 
669
 
670
  if not config.prequantized_flow:
671
  flow_model = quantize_flow_transformer_and_dispatch_float8(
672
+ flow_model,
673
+ flux_device,
674
+ offload_flow=config.offload_flow,
675
+ swap_linears_with_cublaslinear=flux_dtype == torch.float16,
676
+ flow_dtype=flux_dtype,
677
+ quantize_modulation=config.quantize_modulation,
678
+ quantize_flow_embedder_layers=config.quantize_flow_embedder_layers,
679
  )
680
  else:
681
  flow_model.eval().requires_grad_(False)
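A short, hedged usage sketch tying the pipeline entry points above together; the prompt and output handling are illustrative, and the generate() keyword arguments mirror the warmup_dict used in compile().

```python
# Sketch only: load the pipeline from an example config and generate an image.
from flux_pipeline import FluxPipeline

pipe = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev-offload-1-4090.json",
    debug=False,
)
result = pipe.generate(
    prompt="a misty forest at sunrise, light rays between the trees",
    height=768,
    width=768,
    num_steps=24,
    guidance=3.5,
    seed=10,
)
# `result` is the encoded image payload returned by the pipeline; its exact type
# is not shown in this diff, so handling it is left out of this sketch.
```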
main.py CHANGED
@@ -129,6 +129,22 @@ def parse_args():
129
  + "and then saving the state_dict as a safetensors file), "
130
  + "which reduces the size of the checkpoint by about 50% & reduces startup time",
131
  )
 
132
  return parser.parse_args()
133
 
134
 
@@ -171,6 +187,8 @@ def main():
171
  offload_ae=args.offload_ae,
172
  offload_text_enc=args.offload_text_enc,
173
  prequantized_flow=args.prequantized_flow,
 
 
174
  )
175
  app.state.model = FluxPipeline.load_pipeline_from_config(config)
176
 
 
129
  + "and then saving the state_dict as a safetensors file), "
130
  + "which reduces the size of the checkpoint by about 50% & reduces startup time",
131
  )
132
+ parser.add_argument(
133
+ "-nqfm",
134
+ "--no-quantize-flow-modulation",
135
+ action="store_false",
136
+ default=True,
137
+ dest="quantize_modulation",
138
+ help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements",
139
+ )
140
+ parser.add_argument(
141
+ "-qfl",
142
+ "--quantize-flow-embedder-layers",
143
+ action="store_true",
144
+ default=False,
145
+ dest="quantize_flow_embedder_layers",
146
+ help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
147
+ )
148
  return parser.parse_args()
149
 
150
 
 
187
  offload_ae=args.offload_ae,
188
  offload_text_enc=args.offload_text_enc,
189
  prequantized_flow=args.prequantized_flow,
190
+ quantize_modulation=args.quantize_modulation,
191
+ quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,
192
  )
193
  app.state.model = FluxPipeline.load_pipeline_from_config(config)
194
 
modules/conditioner.py CHANGED
@@ -56,6 +56,16 @@ class HFEmbedder(nn.Module):
56
  self.max_length = max_length
57
  self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
58
 
 
59
  if self.is_clip:
60
  self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
61
  version, max_length=max_length
@@ -64,11 +74,6 @@ class HFEmbedder(nn.Module):
64
  self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
65
  version,
66
  **hf_kwargs,
67
- quantization_config=(
68
- auto_quantization_config(quantization_dtype)
69
- if quantization_dtype
70
- else None
71
- ),
72
  )
73
 
74
  else:
@@ -78,11 +83,6 @@ class HFEmbedder(nn.Module):
78
  self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
79
  version,
80
  **hf_kwargs,
81
- quantization_config=(
82
- auto_quantization_config(quantization_dtype)
83
- if quantization_dtype
84
- else None
85
- ),
86
  )
87
 
88
  def offload(self):
 
56
  self.max_length = max_length
57
  self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
58
 
59
+ auto_quant_config = (
60
+ auto_quantization_config(quantization_dtype) if quantization_dtype else None
61
+ )
62
+
63
+ # BNB will move to cuda:0 by default if not specified
64
+ if isinstance(auto_quant_config, BitsAndBytesConfig):
65
+ hf_kwargs["device_map"] = {"": self.device.index}
66
+ if auto_quant_config is not None:
67
+ hf_kwargs["quantization_config"] = auto_quant_config
68
+
69
  if self.is_clip:
70
  self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
71
  version, max_length=max_length
 
74
  self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
75
  version,
76
  **hf_kwargs,
 
 
 
 
 
77
  )
78
 
79
  else:
 
83
  self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
84
  version,
85
  **hf_kwargs,
 
 
 
 
 
86
  )
87
 
88
  def offload(self):
modules/flux_model.py CHANGED
@@ -1,7 +1,11 @@
1
  from collections import namedtuple
2
  import os
 
3
  import torch
4
 
 
 
 
5
  DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
6
  torch.backends.cuda.matmul.allow_tf32 = True
7
  torch.backends.cudnn.allow_tf32 = True
@@ -14,11 +18,6 @@ from torch import Tensor, nn
14
  from pydantic import BaseModel
15
  from torch.nn import functional as F
16
 
17
- try:
18
- from cublas_ops import CublasLinear
19
- except ImportError:
20
- CublasLinear = nn.Linear
21
-
22
 
23
  class FluxParams(BaseModel):
24
  in_channels: int
@@ -116,11 +115,39 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
116
 
117
 
118
  class MLPEmbedder(nn.Module):
119
- def __init__(self, in_dim: int, hidden_dim: int):
 
 
 
 
120
  super().__init__()
121
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
 
 
 
 
122
  self.silu = nn.SiLU()
123
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
 
 
 
 
 
124
 
125
  def forward(self, x: Tensor) -> Tensor:
126
  return self.out_layer(self.silu(self.in_layer(x)))
@@ -148,14 +175,38 @@ class QKNorm(torch.nn.Module):
148
 
149
 
150
  class SelfAttention(nn.Module):
151
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
 
 
 
 
 
 
152
  super().__init__()
 
 
153
  self.num_heads = num_heads
154
  head_dim = dim // num_heads
155
 
156
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 
 
 
 
 
 
 
 
157
  self.norm = QKNorm(head_dim)
158
- self.proj = nn.Linear(dim, dim)
 
 
 
 
 
 
 
 
159
  self.K = 3
160
  self.H = self.num_heads
161
  self.KH = self.K * self.H
@@ -178,11 +229,21 @@ ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])
178
 
179
 
180
  class Modulation(nn.Module):
181
- def __init__(self, dim: int, double: bool):
182
  super().__init__()
 
 
183
  self.is_double = double
184
  self.multiplier = 6 if double else 3
185
- self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
 
 
 
 
 
 
 
 
186
  self.act = nn.SiLU()
187
 
188
  def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
@@ -202,37 +263,83 @@ class DoubleStreamBlock(nn.Module):
202
  mlp_ratio: float,
203
  qkv_bias: bool = False,
204
  dtype: torch.dtype = torch.float16,
 
 
205
  ):
206
  super().__init__()
 
 
207
  self.dtype = dtype
208
 
209
  mlp_hidden_dim = int(hidden_size * mlp_ratio)
210
  self.num_heads = num_heads
211
  self.hidden_size = hidden_size
212
- self.img_mod = Modulation(hidden_size, double=True)
 
 
213
  self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
214
  self.img_attn = SelfAttention(
215
- dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
 
 
 
216
  )
217
 
218
  self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
219
  self.img_mlp = nn.Sequential(
220
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
 
 
 
 
 
 
 
 
221
  nn.GELU(approximate="tanh"),
222
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
 
 
 
 
 
 
 
 
223
  )
224
 
225
- self.txt_mod = Modulation(hidden_size, double=True)
 
 
226
  self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
  self.txt_attn = SelfAttention(
228
- dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
 
 
 
229
  )
230
 
231
  self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
232
  self.txt_mlp = nn.Sequential(
233
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
 
 
 
 
 
 
 
 
234
  nn.GELU(approximate="tanh"),
235
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
 
 
 
 
 
 
 
 
236
  )
237
  self.K = 3
238
  self.H = self.num_heads
@@ -301,8 +408,12 @@ class SingleStreamBlock(nn.Module):
301
  mlp_ratio: float = 4.0,
302
  qk_scale: float | None = None,
303
  dtype: torch.dtype = torch.float16,
 
 
304
  ):
305
  super().__init__()
 
 
306
  self.dtype = dtype
307
  self.hidden_dim = hidden_size
308
  self.num_heads = num_heads
@@ -311,9 +422,25 @@ class SingleStreamBlock(nn.Module):
311
 
312
  self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
313
  # qkv and mlp_in
314
- self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
 
 
 
 
 
 
 
 
315
  # proj and mlp_out
316
- self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
 
 
 
 
 
 
 
 
317
 
318
  self.norm = QKNorm(head_dim)
319
 
@@ -321,7 +448,11 @@ class SingleStreamBlock(nn.Module):
321
  self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
322
 
323
  self.mlp_act = nn.GELU(approximate="tanh")
324
- self.modulation = Modulation(hidden_size, double=False)
 
 
 
 
325
 
326
  self.K = 3
327
  self.H = self.num_heads
@@ -350,11 +481,11 @@ class LastLayer(nn.Module):
350
  def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
351
  super().__init__()
352
  self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
353
- self.linear = CublasLinear(
354
  hidden_size, patch_size * patch_size * out_channels, bias=True
355
  )
356
  self.adaLN_modulation = nn.Sequential(
357
- nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
358
  )
359
 
360
  def forward(self, x: Tensor, vec: Tensor) -> Tensor:
@@ -369,50 +500,96 @@ class Flux(nn.Module):
369
  Transformer model for flow matching on sequences.
370
  """
371
 
372
- def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
373
  super().__init__()
374
 
375
  self.dtype = dtype
376
- self.params = params
377
- self.in_channels = params.in_channels
378
  self.out_channels = self.in_channels
379
- if params.hidden_size % params.num_heads != 0:
 
 
 
 
 
380
  raise ValueError(
381
- f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
382
  )
383
- pe_dim = params.hidden_size // params.num_heads
384
- if sum(params.axes_dim) != pe_dim:
385
  raise ValueError(
386
- f"Got {params.axes_dim} but expected positional dim {pe_dim}"
387
  )
388
- self.hidden_size = params.hidden_size
389
- self.num_heads = params.num_heads
390
  self.pe_embedder = EmbedND(
391
  dim=pe_dim,
392
- theta=params.theta,
393
- axes_dim=params.axes_dim,
394
  dtype=self.dtype,
395
  )
396
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
397
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
398
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
 
 
 
 
 
399
  self.guidance_in = (
400
- MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
401
- if params.guidance_embed
 
 
 
 
 
402
  else nn.Identity()
403
  )
404
- self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
 
 
 
405
 
406
  self.double_blocks = nn.ModuleList(
407
  [
408
  DoubleStreamBlock(
409
  self.hidden_size,
410
  self.num_heads,
411
- mlp_ratio=params.mlp_ratio,
412
- qkv_bias=params.qkv_bias,
413
  dtype=self.dtype,
 
 
414
  )
415
- for _ in range(params.depth)
416
  ]
417
  )
418
 
@@ -421,10 +598,12 @@ class Flux(nn.Module):
421
  SingleStreamBlock(
422
  self.hidden_size,
423
  self.num_heads,
424
- mlp_ratio=params.mlp_ratio,
425
  dtype=self.dtype,
 
 
426
  )
427
- for _ in range(params.depth_single_blocks)
428
  ]
429
  )
430
 
@@ -477,13 +656,17 @@ class Flux(nn.Module):
477
  return img
478
 
479
  @classmethod
480
- def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
 
 
481
  from util import load_config_from_path
482
  from safetensors.torch import load_file
483
 
484
  config = load_config_from_path(path)
485
  with torch.device("meta"):
486
- klass = cls(params=config.params, dtype=dtype).type(dtype)
 
 
487
 
488
  ckpt = load_file(config.ckpt_path, device="cpu")
489
  klass.load_state_dict(ckpt, assign=True)
 
1
  from collections import namedtuple
2
  import os
3
+ from typing import TYPE_CHECKING
4
  import torch
5
 
6
+ if TYPE_CHECKING:
7
+ from util import ModelSpec
8
+
9
  DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
10
  torch.backends.cuda.matmul.allow_tf32 = True
11
  torch.backends.cudnn.allow_tf32 = True
 
18
  from pydantic import BaseModel
19
  from torch.nn import functional as F
20
 
 
 
 
 
 
21
 
22
  class FluxParams(BaseModel):
23
  in_channels: int
 
115
 
116
 
117
  class MLPEmbedder(nn.Module):
118
+ def __init__(
119
+ self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False
120
+ ):
121
+ from float8_quantize import F8Linear
122
+
123
  super().__init__()
124
+ self.in_layer = (
125
+ nn.Linear(in_dim, hidden_dim, bias=True)
126
+ if not prequantized
127
+ else (
128
+ F8Linear(
129
+ in_features=in_dim,
130
+ out_features=hidden_dim,
131
+ bias=True,
132
+ )
133
+ if quantized
134
+ else nn.Linear(in_dim, hidden_dim, bias=True)
135
+ )
136
+ )
137
  self.silu = nn.SiLU()
138
+ self.out_layer = (
139
+ nn.Linear(hidden_dim, hidden_dim, bias=True)
140
+ if not prequantized
141
+ else (
142
+ F8Linear(
143
+ in_features=hidden_dim,
144
+ out_features=hidden_dim,
145
+ bias=True,
146
+ )
147
+ if quantized
148
+ else nn.Linear(hidden_dim, hidden_dim, bias=True)
149
+ )
150
+ )
151
 
152
  def forward(self, x: Tensor) -> Tensor:
153
  return self.out_layer(self.silu(self.in_layer(x)))
 
175
 
176
 
177
  class SelfAttention(nn.Module):
178
+ def __init__(
179
+ self,
180
+ dim: int,
181
+ num_heads: int = 8,
182
+ qkv_bias: bool = False,
183
+ prequantized: bool = False,
184
+ ):
185
  super().__init__()
186
+ from float8_quantize import F8Linear
187
+
188
  self.num_heads = num_heads
189
  head_dim = dim // num_heads
190
 
191
+ self.qkv = (
192
+ nn.Linear(dim, dim * 3, bias=qkv_bias)
193
+ if not prequantized
194
+ else F8Linear(
195
+ in_features=dim,
196
+ out_features=dim * 3,
197
+ bias=qkv_bias,
198
+ )
199
+ )
200
  self.norm = QKNorm(head_dim)
201
+ self.proj = (
202
+ nn.Linear(dim, dim)
203
+ if not prequantized
204
+ else F8Linear(
205
+ in_features=dim,
206
+ out_features=dim,
207
+ bias=True,
208
+ )
209
+ )
210
  self.K = 3
211
  self.H = self.num_heads
212
  self.KH = self.K * self.H
 
229
 
230
 
231
  class Modulation(nn.Module):
232
+ def __init__(self, dim: int, double: bool, quantized_modulation: bool = False):
233
  super().__init__()
234
+ from float8_quantize import F8Linear
235
+
236
  self.is_double = double
237
  self.multiplier = 6 if double else 3
238
+ self.lin = (
239
+ nn.Linear(dim, self.multiplier * dim, bias=True)
240
+ if not quantized_modulation
241
+ else F8Linear(
242
+ in_features=dim,
243
+ out_features=self.multiplier * dim,
244
+ bias=True,
245
+ )
246
+ )
247
  self.act = nn.SiLU()
248
 
249
  def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
 
263
  mlp_ratio: float,
264
  qkv_bias: bool = False,
265
  dtype: torch.dtype = torch.float16,
266
+ quantized_modulation: bool = False,
267
+ prequantized: bool = False,
268
  ):
269
  super().__init__()
270
+ from float8_quantize import F8Linear
271
+
272
  self.dtype = dtype
273
 
274
  mlp_hidden_dim = int(hidden_size * mlp_ratio)
275
  self.num_heads = num_heads
276
  self.hidden_size = hidden_size
277
+ self.img_mod = Modulation(
278
+ hidden_size, double=True, quantized_modulation=quantized_modulation
279
+ )
280
  self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
281
  self.img_attn = SelfAttention(
282
+ dim=hidden_size,
283
+ num_heads=num_heads,
284
+ qkv_bias=qkv_bias,
285
+ prequantized=prequantized,
286
  )
287
 
288
  self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
289
  self.img_mlp = nn.Sequential(
290
+ (
291
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
292
+ if not prequantized
293
+ else F8Linear(
294
+ in_features=hidden_size,
295
+ out_features=mlp_hidden_dim,
296
+ bias=True,
297
+ )
298
+ ),
299
  nn.GELU(approximate="tanh"),
300
+ (
301
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
302
+ if not prequantized
303
+ else F8Linear(
304
+ in_features=mlp_hidden_dim,
305
+ out_features=hidden_size,
306
+ bias=True,
307
+ )
308
+ ),
309
  )
310
 
311
+ self.txt_mod = Modulation(
312
+ hidden_size, double=True, quantized_modulation=quantized_modulation
313
+ )
314
  self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
315
  self.txt_attn = SelfAttention(
316
+ dim=hidden_size,
317
+ num_heads=num_heads,
318
+ qkv_bias=qkv_bias,
319
+ prequantized=prequantized,
320
  )
321
 
322
  self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
323
  self.txt_mlp = nn.Sequential(
324
+ (
325
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True)
326
+ if not prequantized
327
+ else F8Linear(
328
+ in_features=hidden_size,
329
+ out_features=mlp_hidden_dim,
330
+ bias=True,
331
+ )
332
+ ),
333
  nn.GELU(approximate="tanh"),
334
+ (
335
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True)
336
+ if not prequantized
337
+ else F8Linear(
338
+ in_features=mlp_hidden_dim,
339
+ out_features=hidden_size,
340
+ bias=True,
341
+ )
342
+ ),
343
  )
344
  self.K = 3
345
  self.H = self.num_heads
 
408
  mlp_ratio: float = 4.0,
409
  qk_scale: float | None = None,
410
  dtype: torch.dtype = torch.float16,
411
+ quantized_modulation: bool = False,
412
+ prequantized: bool = False,
413
  ):
414
  super().__init__()
415
+ from float8_quantize import F8Linear
416
+
417
  self.dtype = dtype
418
  self.hidden_dim = hidden_size
419
  self.num_heads = num_heads
 
422
 
423
  self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
424
  # qkv and mlp_in
425
+ self.linear1 = (
426
+ nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
427
+ if not prequantized
428
+ else F8Linear(
429
+ in_features=hidden_size,
430
+ out_features=hidden_size * 3 + self.mlp_hidden_dim,
431
+ bias=True,
432
+ )
433
+ )
434
  # proj and mlp_out
435
+ self.linear2 = (
436
+ nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
437
+ if not prequantized
438
+ else F8Linear(
439
+ in_features=hidden_size + self.mlp_hidden_dim,
440
+ out_features=hidden_size,
441
+ bias=True,
442
+ )
443
+ )
444
 
445
  self.norm = QKNorm(head_dim)
446
 
 
448
  self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
449
 
450
  self.mlp_act = nn.GELU(approximate="tanh")
451
+ self.modulation = Modulation(
452
+ hidden_size,
453
+ double=False,
454
+ quantized_modulation=quantized_modulation and prequantized,
455
+ )
456
 
457
  self.K = 3
458
  self.H = self.num_heads
 
481
  def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
482
  super().__init__()
483
  self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
484
+ self.linear = nn.Linear(
485
  hidden_size, patch_size * patch_size * out_channels, bias=True
486
  )
487
  self.adaLN_modulation = nn.Sequential(
488
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
489
  )
490
 
491
  def forward(self, x: Tensor, vec: Tensor) -> Tensor:
 
500
  Transformer model for flow matching on sequences.
501
  """
502
 
503
+ def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16):
504
  super().__init__()
505
 
506
  self.dtype = dtype
507
+ self.params = config.params
508
+ self.in_channels = config.params.in_channels
509
  self.out_channels = self.in_channels
510
+ prequantized_flow = config.prequantized_flow
511
+ quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
512
+ quantized_modulation = config.quantize_modulation and prequantized_flow
513
+ from float8_quantize import F8Linear
514
+
515
+ if config.params.hidden_size % config.params.num_heads != 0:
516
  raise ValueError(
517
+ f"Hidden size {config.params.hidden_size} must be divisible by num_heads {config.params.num_heads}"
518
  )
519
+ pe_dim = config.params.hidden_size // config.params.num_heads
520
+ if sum(config.params.axes_dim) != pe_dim:
521
  raise ValueError(
522
+ f"Got {config.params.axes_dim} but expected positional dim {pe_dim}"
523
  )
524
+ self.hidden_size = config.params.hidden_size
525
+ self.num_heads = config.params.num_heads
526
  self.pe_embedder = EmbedND(
527
  dim=pe_dim,
528
+ theta=config.params.theta,
529
+ axes_dim=config.params.axes_dim,
530
  dtype=self.dtype,
531
  )
532
+ self.img_in = (
533
+ nn.Linear(self.in_channels, self.hidden_size, bias=True)
534
+ if not prequantized_flow
535
+ else (
536
+ F8Linear(
537
+ in_features=self.in_channels,
538
+ out_features=self.hidden_size,
539
+ bias=True,
540
+ )
541
+ if quantized_embedders
542
+ else nn.Linear(self.in_channels, self.hidden_size, bias=True)
543
+ )
544
+ )
545
+ self.time_in = MLPEmbedder(
546
+ in_dim=256,
547
+ hidden_dim=self.hidden_size,
548
+ prequantized=prequantized_flow,
549
+ quantized=quantized_embedders,
550
+ )
551
+ self.vector_in = MLPEmbedder(
552
+ config.params.vec_in_dim,
553
+ self.hidden_size,
554
+ prequantized=prequantized_flow,
555
+ quantized=quantized_embedders,
556
+ )
557
  self.guidance_in = (
558
+ MLPEmbedder(
559
+ in_dim=256,
560
+ hidden_dim=self.hidden_size,
561
+ prequantized=prequantized_flow,
562
+ quantized=quantized_embedders,
563
+ )
564
+ if config.params.guidance_embed
565
  else nn.Identity()
566
  )
567
+ self.txt_in = (
568
+ nn.Linear(config.params.context_in_dim, self.hidden_size)
569
+ if not quantized_embedders
570
+ else (
571
+ F8Linear(
572
+ in_features=config.params.context_in_dim,
573
+ out_features=self.hidden_size,
574
+ bias=True,
575
+ )
576
+ if quantized_embedders
577
+ else nn.Linear(config.params.context_in_dim, self.hidden_size)
578
+ )
579
+ )
580
 
581
  self.double_blocks = nn.ModuleList(
582
  [
583
  DoubleStreamBlock(
584
  self.hidden_size,
585
  self.num_heads,
586
+ mlp_ratio=config.params.mlp_ratio,
587
+ qkv_bias=config.params.qkv_bias,
588
  dtype=self.dtype,
589
+ quantized_modulation=quantized_modulation,
590
+ prequantized=prequantized_flow,
591
  )
592
+ for _ in range(config.params.depth)
593
  ]
594
  )
595
 
 
598
  SingleStreamBlock(
599
  self.hidden_size,
600
  self.num_heads,
601
+ mlp_ratio=config.params.mlp_ratio,
602
  dtype=self.dtype,
603
+ quantized_modulation=quantized_modulation,
604
+ prequantized=prequantized_flow,
605
  )
606
+ for _ in range(config.params.depth_single_blocks)
607
  ]
608
  )
609
 
 
656
  return img
657
 
658
  @classmethod
659
+ def from_pretrained(
660
+ cls: "Flux", path: str, dtype: torch.dtype = torch.float16
661
+ ) -> "Flux":
662
  from util import load_config_from_path
663
  from safetensors.torch import load_file
664
 
665
  config = load_config_from_path(path)
666
  with torch.device("meta"):
667
+ klass = cls(config=config, dtype=dtype)
668
+ if not config.prequantized_flow:
669
+ klass.type(dtype)
670
 
671
  ckpt = load_file(config.ckpt_path, device="cpu")
672
  klass.load_state_dict(ckpt, assign=True)
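A brief, hedged usage sketch of the refactored constructor path; the config path is one of the example configs referenced in the README, and the checkpoint it points to is assumed to exist.

```python
# The flow transformer is now built from a ModelSpec config, so a prequantized
# config constructs F8Linear layers in place and the float8 state_dict can be
# loaded directly with assign=True.
import torch
from modules.flux_model import Flux

flow = Flux.from_pretrained("configs/config-dev-1-4090.json", dtype=torch.float16)
flow.eval().requires_grad_(False)
```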
modules/flux_model_f8.py DELETED
@@ -1,491 +0,0 @@
1
- from collections import namedtuple
2
- import os
3
- import torch
4
-
5
- DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cudnn.allow_tf32 = True
8
- torch.backends.cudnn.benchmark = True
9
- torch.backends.cudnn.benchmark_limit = 20
10
- torch.set_float32_matmul_precision("high")
11
- import math
12
-
13
- from torch import Tensor, nn
14
- from pydantic import BaseModel
15
- from torch.nn import functional as F
16
- from float8_quantize import F8Linear
17
-
18
- try:
19
- from cublas_ops import CublasLinear
20
- except ImportError:
21
- CublasLinear = nn.Linear
22
-
23
-
24
- class FluxParams(BaseModel):
25
- in_channels: int
26
- vec_in_dim: int
27
- context_in_dim: int
28
- hidden_size: int
29
- mlp_ratio: float
30
- num_heads: int
31
- depth: int
32
- depth_single_blocks: int
33
- axes_dim: list[int]
34
- theta: int
35
- qkv_bias: bool
36
- guidance_embed: bool
37
-
38
-
39
- # attention is always same shape each time it's called per H*W, so compile with fullgraph
40
- # @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
41
- def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
42
- q, k = apply_rope(q, k, pe)
43
- x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
44
- x = x.reshape(*x.shape[:-2], -1)
45
- return x
46
-
47
-
48
- # @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
49
- def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
50
- scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
51
- omega = 1.0 / (theta**scale)
52
- out = torch.einsum("...n,d->...nd", pos, omega)
53
- out = torch.stack(
54
- [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
55
- )
56
- out = out.reshape(*out.shape[:-1], 2, 2)
57
- return out
58
-
59
-
60
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
61
- xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
62
- xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
63
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
64
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
65
- return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
66
-
67
-
68
- class EmbedND(nn.Module):
69
- def __init__(
70
- self,
71
- dim: int,
72
- theta: int,
73
- axes_dim: list[int],
74
- dtype: torch.dtype = torch.bfloat16,
75
- ):
76
- super().__init__()
77
- self.dim = dim
78
- self.theta = theta
79
- self.axes_dim = axes_dim
80
- self.dtype = dtype
81
-
82
- def forward(self, ids: Tensor) -> Tensor:
83
- n_axes = ids.shape[-1]
84
- emb = torch.cat(
85
- [
86
- rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
87
- for i in range(n_axes)
88
- ],
89
- dim=-3,
90
- )
91
-
92
- return emb.unsqueeze(1)
93
-
94
-
95
- def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
96
- """
97
- Create sinusoidal timestep embeddings.
98
- :param t: a 1-D Tensor of N indices, one per batch element.
99
- These may be fractional.
100
- :param dim: the dimension of the output.
101
- :param max_period: controls the minimum frequency of the embeddings.
102
- :return: an (N, D) Tensor of positional embeddings.
103
- """
104
- t = time_factor * t
105
- half = dim // 2
106
- freqs = torch.exp(
107
- -math.log(max_period)
108
- * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
109
- / half
110
- )
111
-
112
- args = t[:, None].float() * freqs[None]
113
-     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-     if dim % 2:
-         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-     return embedding
-
-
- class MLPEmbedder(nn.Module):
-     def __init__(self, in_dim: int, hidden_dim: int):
-         super().__init__()
-         self.in_layer = F8Linear(in_dim, hidden_dim, bias=True)
-         self.silu = nn.SiLU()
-         self.out_layer = F8Linear(hidden_dim, hidden_dim, bias=True)
-
-     def forward(self, x: Tensor) -> Tensor:
-         return self.out_layer(self.silu(self.in_layer(x)))
-
-
- class RMSNorm(torch.nn.Module):
-     def __init__(self, dim: int):
-         super().__init__()
-         self.scale = nn.Parameter(torch.ones(dim))
-
-     def forward(self, x: Tensor):
-         return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
-
-
- class QKNorm(torch.nn.Module):
-     def __init__(self, dim: int):
-         super().__init__()
-         self.query_norm = RMSNorm(dim)
-         self.key_norm = RMSNorm(dim)
-
-     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
-         q = self.query_norm(q)
-         k = self.key_norm(k)
-         return q, k
-
-
- class SelfAttention(nn.Module):
-     def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
-         super().__init__()
-         self.num_heads = num_heads
-         head_dim = dim // num_heads
-
-         self.qkv = F8Linear(dim, dim * 3, bias=qkv_bias)
-         self.norm = QKNorm(head_dim)
-         self.proj = F8Linear(dim, dim)
-         self.K = 3
-         self.H = self.num_heads
-         self.KH = self.K * self.H
-
-     def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
-         B, L, D = x.shape
-         q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
-         return q, k, v
-
-     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-         qkv = self.qkv(x)
-         q, k, v = self.rearrange_for_norm(qkv)
-         q, k = self.norm(q, k, v)
-         x = attention(q, k, v, pe=pe)
-         x = self.proj(x)
-         return x
-
-
- ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])
-
-
- class Modulation(nn.Module):
-     def __init__(self, dim: int, double: bool):
-         super().__init__()
-         self.is_double = double
-         self.multiplier = 6 if double else 3
-         self.lin = F8Linear(dim, self.multiplier * dim, bias=True)
-         self.act = nn.SiLU()
-
-     def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
-         out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)
-
-         return (
-             ModulationOut(*out[:3]),
-             ModulationOut(*out[3:]) if self.is_double else None,
-         )
-
-
- class DoubleStreamBlock(nn.Module):
-     def __init__(
-         self,
-         hidden_size: int,
-         num_heads: int,
-         mlp_ratio: float,
-         qkv_bias: bool = False,
-         dtype: torch.dtype = torch.float16,
-     ):
-         super().__init__()
-         self.dtype = dtype
-
-         mlp_hidden_dim = int(hidden_size * mlp_ratio)
-         self.num_heads = num_heads
-         self.hidden_size = hidden_size
-         self.img_mod = Modulation(hidden_size, double=True)
-         self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.img_attn = SelfAttention(
-             dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
-         )
-
-         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.img_mlp = nn.Sequential(
-             F8Linear(hidden_size, mlp_hidden_dim, bias=True),
-             nn.GELU(approximate="tanh"),
-             F8Linear(mlp_hidden_dim, hidden_size, bias=True),
-         )
-
-         self.txt_mod = Modulation(hidden_size, double=True)
-         self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.txt_attn = SelfAttention(
-             dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
-         )
-
-         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.txt_mlp = nn.Sequential(
-             F8Linear(hidden_size, mlp_hidden_dim, bias=True),
-             nn.GELU(approximate="tanh"),
-             F8Linear(mlp_hidden_dim, hidden_size, bias=True),
-         )
-         self.K = 3
-         self.H = self.num_heads
-         self.KH = self.K * self.H
-
-     def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
-         B, L, D = x.shape
-         q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
-         return q, k, v
-
-     def forward(
-         self,
-         img: Tensor,
-         txt: Tensor,
-         vec: Tensor,
-         pe: Tensor,
-     ) -> tuple[Tensor, Tensor]:
-         img_mod1, img_mod2 = self.img_mod(vec)
-         txt_mod1, txt_mod2 = self.txt_mod(vec)
-
-         # prepare image for attention
-         img_modulated = self.img_norm1(img)
-         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
-         img_qkv = self.img_attn.qkv(img_modulated)
-         img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
-         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-         # prepare txt for attention
-         txt_modulated = self.txt_norm1(txt)
-         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
-         txt_qkv = self.txt_attn.qkv(txt_modulated)
-         txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
-         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-         q = torch.cat((txt_q, img_q), dim=2)
-         k = torch.cat((txt_k, img_k), dim=2)
-         v = torch.cat((txt_v, img_v), dim=2)
-
-         attn = attention(q, k, v, pe=pe)
-         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-         # calculate the img bloks
-         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-         img = img + img_mod2.gate * self.img_mlp(
-             (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-         ).clamp(min=-384 * 2, max=384 * 2)
-
-         # calculate the txt bloks
-         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-         txt = txt + txt_mod2.gate * self.txt_mlp(
-             (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-         ).clamp(min=-384 * 2, max=384 * 2)
-
-         return img, txt
-
-
- class SingleStreamBlock(nn.Module):
-     """
-     A DiT block with parallel linear layers as described in
-     https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-     """
-
-     def __init__(
-         self,
-         hidden_size: int,
-         num_heads: int,
-         mlp_ratio: float = 4.0,
-         qk_scale: float | None = None,
-         dtype: torch.dtype = torch.float16,
-     ):
-         super().__init__()
-         self.dtype = dtype
-         self.hidden_dim = hidden_size
-         self.num_heads = num_heads
-         head_dim = hidden_size // num_heads
-         self.scale = qk_scale or head_dim**-0.5
-
-         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-         # qkv and mlp_in
-         self.linear1 = F8Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
-         # proj and mlp_out
-         self.linear2 = F8Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
-
-         self.norm = QKNorm(head_dim)
-
-         self.hidden_size = hidden_size
-         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-
-         self.mlp_act = nn.GELU(approximate="tanh")
-         self.modulation = Modulation(hidden_size, double=False)
-
-         self.K = 3
-         self.H = self.num_heads
-         self.KH = self.K * self.H
-
-     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
-         mod = self.modulation(vec)[0]
-         pre_norm = self.pre_norm(x)
-         x_mod = (1 + mod.scale) * pre_norm + mod.shift
-         qkv, mlp = torch.split(
-             self.linear1(x_mod),
-             [3 * self.hidden_size, self.mlp_hidden_dim],
-             dim=-1,
-         )
-         B, L, D = qkv.shape
-         q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
-         q, k = self.norm(q, k, v)
-         attn = attention(q, k, v, pe=pe)
-         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
-             min=-384 * 4, max=384 * 4
-         )
-         return x + mod.gate * output
-
-
- class LastLayer(nn.Module):
-     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
-         super().__init__()
-         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.linear = CublasLinear(
-             hidden_size, patch_size * patch_size * out_channels, bias=True
-         )
-         self.adaLN_modulation = nn.Sequential(
-             nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
-         )
-
-     def forward(self, x: Tensor, vec: Tensor) -> Tensor:
-         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
-         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
-         x = self.linear(x)
-         return x
-
-
- class Flux(nn.Module):
-     """
-     Transformer model for flow matching on sequences.
-     """
-
-     def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
-         super().__init__()
-
-         self.dtype = dtype
-         self.params = params
-         self.in_channels = params.in_channels
-         self.out_channels = self.in_channels
-         if params.hidden_size % params.num_heads != 0:
-             raise ValueError(
-                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
-             )
-         pe_dim = params.hidden_size // params.num_heads
-         if sum(params.axes_dim) != pe_dim:
-             raise ValueError(
-                 f"Got {params.axes_dim} but expected positional dim {pe_dim}"
-             )
-         self.hidden_size = params.hidden_size
-         self.num_heads = params.num_heads
-         self.pe_embedder = EmbedND(
-             dim=pe_dim,
-             theta=params.theta,
-             axes_dim=params.axes_dim,
-             dtype=self.dtype,
-         )
-         self.img_in = F8Linear(self.in_channels, self.hidden_size, bias=True)
-         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
-         self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
-         self.guidance_in = (
-             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
-             if params.guidance_embed
-             else nn.Identity()
-         )
-         self.txt_in = F8Linear(params.context_in_dim, self.hidden_size)
-
-         self.double_blocks = nn.ModuleList(
-             [
-                 DoubleStreamBlock(
-                     self.hidden_size,
-                     self.num_heads,
-                     mlp_ratio=params.mlp_ratio,
-                     qkv_bias=params.qkv_bias,
-                     dtype=self.dtype,
-                 )
-                 for _ in range(params.depth)
-             ]
-         )
-
-         self.single_blocks = nn.ModuleList(
-             [
-                 SingleStreamBlock(
-                     self.hidden_size,
-                     self.num_heads,
-                     mlp_ratio=params.mlp_ratio,
-                     dtype=self.dtype,
-                 )
-                 for _ in range(params.depth_single_blocks)
-             ]
-         )
-
-         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
-
-     def forward(
-         self,
-         img: Tensor,
-         img_ids: Tensor,
-         txt: Tensor,
-         txt_ids: Tensor,
-         timesteps: Tensor,
-         y: Tensor,
-         guidance: Tensor | None = None,
-     ) -> Tensor:
-         if img.ndim != 3 or txt.ndim != 3:
-             raise ValueError("Input img and txt tensors must have 3 dimensions.")
-
-         # running on sequences img
-         img = self.img_in(img)
-         vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))
-
-         if self.params.guidance_embed:
-             if guidance is None:
-                 raise ValueError(
-                     "Didn't get guidance strength for guidance distilled model."
-                 )
-             vec = vec + self.guidance_in(
-                 timestep_embedding(guidance, 256).type(self.dtype)
-             )
-         vec = vec + self.vector_in(y)
-
-         txt = self.txt_in(txt)
-
-         ids = torch.cat((txt_ids, img_ids), dim=1)
-         pe = self.pe_embedder(ids)
-
-         # double stream blocks
-         for block in self.double_blocks:
-             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
-
-         img = torch.cat((txt, img), 1)
-
-         # single stream blocks
-         for block in self.single_blocks:
-             img = block(img, vec=vec, pe=pe)
-
-         img = img[:, txt.shape[1] :, ...]
-         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
-         return img
-
-     @classmethod
-     def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
-         from util import load_config_from_path
-         from safetensors.torch import load_file
-
-         config = load_config_from_path(path)
-         with torch.device("meta"):
-             klass = cls(params=config.params, dtype=dtype).type(dtype)
-
-         ckpt = load_file(config.ckpt_path, device="cpu")
-         klass.load_state_dict(ckpt, assign=True)
-         return klass.to("cpu")
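
The `rearrange_for_norm` helpers in the deleted blocks above are just a reshape/permute that splits a fused qkv projection into per-head query, key, and value tensors. A minimal standalone sketch of that shape logic (the sizes below are illustrative, not the model's real dimensions):

```python
import torch

# Illustrative sizes only: batch 2, sequence length 16, 4 heads of width 8.
B, L, H, head_dim = 2, 16, 4, 8
K = 3  # fused query / key / value
D = K * H * head_dim

qkv = torch.randn(B, L, D)  # stand-in for the output of a fused qkv linear layer

# Same reshape/permute as rearrange_for_norm: result has shape (K, B, H, L, head_dim),
# so unpacking along the first axis yields q, k, v of shape (B, H, L, head_dim).
q, k, v = qkv.reshape(B, L, K, H, D // (K * H)).permute(2, 0, 3, 1, 4)

assert q.shape == k.shape == v.shape == (B, H, L, head_dim)
```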
util.py CHANGED
@@ -6,14 +6,17 @@ import torch
  from modules.autoencoder import AutoEncoder, AutoEncoderParams
  from modules.conditioner import HFEmbedder
  from modules.flux_model import Flux, FluxParams
- from modules.flux_model_f8 import Flux as FluxF8
  from safetensors.torch import load_file as load_sft
+
  try:
      from enum import StrEnum
  except:
      from enum import Enum
+
      class StrEnum(str, Enum):
          pass
+
+
  from pydantic import BaseModel, ConfigDict
  from loguru import logger

@@ -61,6 +64,11 @@ class ModelSpec(BaseModel):
      offload_flow: bool = False
      prequantized_flow: bool = False

+     # Improved precision via not quanitzing the modulation linear layers
+     quantize_modulation: bool = True
+     # Improved precision via not quanitzing the flow embedder layers
+     quantize_flow_embedder_layers: bool = False
+
      model_config: ConfigDict = {
          "arbitrary_types_allowed": True,
          "use_enum_values": True,
@@ -84,6 +92,8 @@ def parse_device(device: str | torch.device | None) -> torch.device:


  def into_dtype(dtype: str) -> torch.dtype:
+     if isinstance(dtype, torch.dtype):
+         return dtype
      if dtype == "float16":
          return torch.float16
      elif dtype == "bfloat16":
@@ -125,6 +135,8 @@ def load_config(
      quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
      quant_ae: bool = False,
      prequantized_flow: bool = False,
+     quantize_modulation: bool = True,
+     quantize_flow_embedder_layers: bool = False,
  ) -> ModelSpec:
      """
      Load a model configuration using the passed arguments.
@@ -192,6 +204,8 @@ def load_config(
          }.get(quant_text_enc, None),
          ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
          prequantized_flow=prequantized_flow,
+         quantize_modulation=quantize_modulation,
+         quantize_flow_embedder_layers=quantize_flow_embedder_layers,
      )


@@ -219,16 +233,14 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
          )


- def load_flow_model(config: ModelSpec) -> Flux | FluxF8:
+ def load_flow_model(config: ModelSpec) -> Flux:
      ckpt_path = config.ckpt_path
      FluxClass = Flux
-     if config.prequantized_flow:
-         FluxClass = FluxF8

      with torch.device("meta"):
-         model = FluxClass(config.params, dtype=into_dtype(config.flow_dtype)).type(
-             into_dtype(config.flow_dtype)
-         )
+         model = FluxClass(config, dtype=into_dtype(config.flow_dtype))
+         if not config.prequantized_flow:
+             model.type(into_dtype(config.flow_dtype))

      if ckpt_path is not None:
          # load_sft doesn't support torch.device
@@ -279,7 +291,7 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:


  class LoadedModels(BaseModel):
-     flow: Flux | FluxF8
+     flow: Flux
      ae: AutoEncoder
      clip: HFEmbedder
      t5: HFEmbedder
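
The reworked `load_flow_model` above (like the `Flux.from_pretrained` classmethod in the removed module) relies on the same PyTorch idiom: build the module under `torch.device("meta")` so no weight memory is allocated, then hand the checkpoint tensors in with `load_state_dict(..., assign=True)` so they are adopted rather than copied. A minimal sketch of that idiom, independent of this repo's classes (`TinyModel` is a hypothetical stand-in; `assign=True` needs torch >= 2.1):

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Toy stand-in for a large transformer."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)


# 1. Instantiate on the meta device: parameters carry shape/dtype but no storage.
with torch.device("meta"):
    model = TinyModel()

# 2. Obtain real tensors from somewhere else (here a fresh CPU state dict stands in
#    for a safetensors checkpoint loaded via load_file / load_sft).
state_dict = TinyModel().state_dict()

# 3. assign=True replaces the meta parameters with the loaded tensors outright,
#    so the weights are never allocated twice.
model.load_state_dict(state_dict, assign=True)

print(model.proj.weight.device)  # cpu
```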