aredden committed on commit 28dec30 (1 parent: c4a514f)

Add offloading & improved fp8 inference.
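For orientation, a minimal end-to-end sketch of what this commit enables, mirroring the new __main__ block in flux_pipeline.py below (the import path is assumed from the file layout; the offload behaviour comes entirely from the flags in the new configs/config-dev-offload.json):

import io

from flux_pipeline import FluxPipeline

# The offload_* flags in the config keep the text encoders, VAE and flow
# transformer on CPU except while each one is actually in use.
pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev-offload.json")

jpeg: io.BytesIO = pipe.generate(
    prompt="Street photography portrait of a woman in a red kimono",
    height=1024,
    width=576,
    num_steps=24,
    guidance=3.5,
    seed=10,
)
open("out.jpg", "wb").write(jpeg.read())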
configs/config-dev-eval.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:1",
+   "ae_device": "cuda:1",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qfloat8",
+   "num_to_quant": 22,
+   "compile_extras": false,
+   "compile_blocks": false
+ }
configs/config-dev-offload.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "version": "flux-dev",
+   "params": {
+     "in_channels": 64,
+     "vec_in_dim": 768,
+     "context_in_dim": 4096,
+     "hidden_size": 3072,
+     "mlp_ratio": 4.0,
+     "num_heads": 24,
+     "depth": 19,
+     "depth_single_blocks": 38,
+     "axes_dim": [
+       16,
+       56,
+       56
+     ],
+     "theta": 10000,
+     "qkv_bias": true,
+     "guidance_embed": true
+   },
+   "ae_params": {
+     "resolution": 256,
+     "in_channels": 3,
+     "ch": 128,
+     "out_ch": 3,
+     "ch_mult": [
+       1,
+       2,
+       4,
+       4
+     ],
+     "num_res_blocks": 2,
+     "z_channels": 16,
+     "scale_factor": 0.3611,
+     "shift_factor": 0.1159
+   },
+   "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+   "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+   "repo_id": "black-forest-labs/FLUX.1-dev",
+   "repo_flow": "flux1-dev.sft",
+   "repo_ae": "ae.sft",
+   "text_enc_max_length": 512,
+   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+   "text_enc_device": "cuda:0",
+   "ae_device": "cuda:0",
+   "flux_device": "cuda:0",
+   "flow_dtype": "float16",
+   "ae_dtype": "bfloat16",
+   "text_enc_dtype": "bfloat16",
+   "flow_quantization_dtype": "qfloat8",
+   "text_enc_quantization_dtype": "qint4",
+   "num_to_quant": 22,
+   "compile_extras": false,
+   "compile_blocks": false,
+   "offload_text_encoder": true,
+   "offload_vae": true,
+   "offload_flow": true
+ }
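Compared with config-dev-eval.json, this config puts every component on cuda:0, switches the T5 encoder to 4-bit ("qint4", handled by a bitsandbytes NF4 config in modules/conditioner.py below), and turns on the three new offload flags. A rough sketch of how those flags surface once the JSON is loaded, assuming load_config_from_path from util.py as used by the pipeline:

from util import load_config_from_path

config = load_config_from_path("configs/config-dev-offload.json")
# New ModelSpec fields added in util.py by this commit (they default to False):
assert config.offload_text_encoder and config.offload_vae and config.offload_flow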
configs/config-dev.json CHANGED
@@ -50,6 +50,6 @@
    "flow_quantization_dtype": "qfloat8",
    "text_enc_quantization_dtype": "qfloat8",
    "num_to_quant": 22,
-   "compile_extras": false,
-   "compile_blocks": false
+   "compile_extras": true,
+   "compile_blocks": true
  }
float8_quantize.py ADDED
@@ -0,0 +1,288 @@
+ import torch
+ import torch.nn as nn
+ from torchao.float8.float8_utils import (
+     amax_to_scale,
+     tensor_to_amax,
+     to_fp8_saturated,
+ )
+ from torch.nn import init
+ import math
+ from torch.compiler import is_compiling
+
+
+ try:
+     from cublas_ops import CublasLinear
+ except ImportError:
+     CublasLinear = type(None)
+
+
+ class F8Linear(nn.Module):
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         device=None,
+         dtype=None,
+         float8_dtype=torch.float8_e4m3fn,
+         float_weight: torch.Tensor = None,
+         float_bias: torch.Tensor = None,
+         num_scale_trials: int = 24,
+         input_float8_dtype=torch.float8_e5m2,
+     ) -> None:
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.float8_dtype = float8_dtype
+         self.input_float8_dtype = input_float8_dtype
+         self.input_scale_initialized = False
+         self.weight_initialized = False
+         self.max_value = torch.finfo(self.float8_dtype).max
+         self.input_max_value = torch.finfo(self.input_float8_dtype).max
+         factory_kwargs = {"dtype": dtype, "device": device}
+         if float_weight is None:
+             self.weight = nn.Parameter(
+                 torch.empty((out_features, in_features), **factory_kwargs)
+             )
+         else:
+             self.weight = nn.Parameter(
+                 float_weight, requires_grad=float_weight.requires_grad
+             )
+         if float_bias is None:
+             if bias:
+                 self.bias = nn.Parameter(
+                     torch.empty(out_features, **factory_kwargs),
+                     requires_grad=bias.requires_grad,
+                 )
+             else:
+                 self.register_parameter("bias", None)
+         else:
+             self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad)
+         self.num_scale_trials = num_scale_trials
+         self.input_amax_trials = torch.zeros(
+             num_scale_trials, requires_grad=False, device=device, dtype=torch.float32
+         )
+         self.trial_index = 0
+         self.register_buffer("scale", None)
+         self.register_buffer(
+             "input_scale",
+             None,
+         )
+         self.register_buffer(
+             "float8_data",
+             None,
+         )
+         self.scale_reciprocal = self.register_buffer("scale_reciprocal", None)
+         self.input_scale_reciprocal = self.register_buffer(
+             "input_scale_reciprocal", None
+         )
+
+     def quantize_weight(self):
+         if self.weight_initialized:
+             return
+         amax = tensor_to_amax(self.weight.data)
+         scale = amax_to_scale(amax, self.float8_dtype, self.weight.dtype)
+         self.float8_data = to_fp8_saturated(self.weight.data * scale, self.float8_dtype)
+         self.scale = scale.float()
+         self.weight_initialized = True
+         self.scale_reciprocal = self.scale.reciprocal().float()
+         self.weight.data = torch.zeros(
+             1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
+         )
+
+     def quantize_input(self, x: torch.Tensor):
+         if self.input_scale_initialized:
+             return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
+         elif self.trial_index < self.num_scale_trials:
+             amax = tensor_to_amax(x)
+             self.input_amax_trials[self.trial_index] = amax
+             self.trial_index += 1
+             self.input_scale = amax_to_scale(
+                 self.input_amax_trials[: self.trial_index].max(),
+                 self.input_float8_dtype,
+                 self.weight.dtype,
+             )
+             self.input_scale_reciprocal = self.input_scale.reciprocal()
+             return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
+         else:
+             self.input_scale = amax_to_scale(
+                 self.input_amax_trials.max(), self.input_float8_dtype, self.weight.dtype
+             )
+             self.input_scale_reciprocal = self.input_scale.reciprocal()
+             self.input_scale_initialized = True
+             return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
+
+     def reset_parameters(self) -> None:
+         if self.weight_initialized:
+             self.weight = nn.Parameter(
+                 torch.empty(
+                     (self.out_features, self.in_features),
+                     **{
+                         "dtype": self.weight.dtype,
+                         "device": self.weight.device,
+                     },
+                 )
+             )
+             self.weight_initialized = False
+             self.input_scale_initialized = False
+             self.trial_index = 0
+             self.input_amax_trials.zero_()
+         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+         if self.bias is not None:
+             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             init.uniform_(self.bias, -bound, bound)
+         self.quantize_weight()
+         self.max_value = torch.finfo(self.float8_dtype).max
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.input_scale_initialized or is_compiling():
+             x = (
+                 x.mul(self.input_scale)
+                 .clamp(min=-self.input_max_value, max=self.input_max_value)
+                 .type(self.input_float8_dtype)
+             )
+         else:
+             x = self.quantize_input(x)
+
+         prev_dims = x.shape[:-1]
+
+         x = x.view(-1, self.in_features)
+
+         # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
+         return torch._scaled_mm(
+             x,
+             self.float8_data.T,
+             self.input_scale_reciprocal,
+             self.scale_reciprocal,
+             bias=self.bias,
+             out_dtype=self.weight.dtype,
+             use_fast_accum=True,
+         ).view(*prev_dims, self.out_features)
+
+     @classmethod
+     def from_linear(
+         cls,
+         linear: nn.Linear,
+         float8_dtype=torch.float8_e4m3fn,
+         input_float8_dtype=torch.float8_e5m2,
+     ):
+         f8_lin = cls(
+             in_features=linear.in_features,
+             out_features=linear.out_features,
+             bias=linear.bias is not None,
+             device=linear.weight.device,
+             dtype=linear.weight.dtype,
+             float8_dtype=float8_dtype,
+             float_weight=linear.weight.data,
+             float_bias=(linear.bias.data if linear.bias is not None else None),
+             input_float8_dtype=input_float8_dtype,
+         )
+         f8_lin.quantize_weight()
+         return f8_lin
+
+
+ def recursive_swap_linears(
+     model: nn.Module,
+     float8_dtype=torch.float8_e4m3fn,
+     input_float8_dtype=torch.float8_e5m2,
+ ):
+     """
+     Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
+
+     This function traverses the model's structure and replaces each nn.Linear
+     instance with an F8Linear instance, which uses 8-bit floating point
+     quantization for weights. The original linear layer's weights are deleted
+     after conversion to save memory.
+
+     Args:
+         model (nn.Module): The PyTorch model to modify.
+
+     Note:
+         This function modifies the model in-place. After calling this function,
+         all linear layers in the model will be using 8-bit quantization.
+     """
+     for name, child in model.named_children():
+         if isinstance(child, nn.Linear) and not isinstance(
+             child, (F8Linear, CublasLinear)
+         ):
+
+             setattr(
+                 model,
+                 name,
+                 F8Linear.from_linear(
+                     child,
+                     float8_dtype=float8_dtype,
+                     input_float8_dtype=input_float8_dtype,
+                 ),
+             )
+             del child
+         else:
+             recursive_swap_linears(child)
+
+
+ @torch.inference_mode()
+ def quantize_flow_transformer_and_dispatch_float8(
+     flow_model: nn.Module,
+     device=torch.device("cuda"),
+     float8_dtype=torch.float8_e4m3fn,
+     input_float8_dtype=torch.float8_e5m2,
+     offload_flow=False,
+ ):
+     """
+     Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
+     """
+     for i, module in enumerate(flow_model.double_blocks):
+         module.to(device)
+         module.eval()
+         recursive_swap_linears(
+             module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
+         )
+         torch.cuda.empty_cache()
+     for i, module in enumerate(flow_model.single_blocks):
+         module.to(device)
+         module.eval()
+         recursive_swap_linears(
+             module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
+         )
+         torch.cuda.empty_cache()
+     to_gpu_extras = [
+         "vector_in",
+         "img_in",
+         "txt_in",
+         "time_in",
+         "guidance_in",
+         "final_layer",
+         "pe_embedder",
+     ]
+     for module in to_gpu_extras:
+         m_extra = getattr(flow_model, module)
+         if m_extra is None:
+             continue
+         m_extra.to(device)
+         m_extra.eval()
+         if isinstance(m_extra, nn.Linear) and not isinstance(
+             m_extra, (F8Linear, CublasLinear)
+         ):
+             setattr(
+                 flow_model,
+                 module,
+                 F8Linear.from_linear(
+                     m_extra,
+                     float8_dtype=float8_dtype,
+                     input_float8_dtype=input_float8_dtype,
+                 ),
+             )
+             del m_extra
+         elif module != "final_layer":
+             recursive_swap_linears(
+                 m_extra,
+                 float8_dtype=float8_dtype,
+                 input_float8_dtype=input_float8_dtype,
+             )
+         torch.cuda.empty_cache()
+     if offload_flow:
+         flow_model.to("cpu")
+         torch.cuda.empty_cache()
+     return flow_model
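A quick usage sketch for the module above (mine, not part of the commit): wrap a single nn.Linear with F8Linear.from_linear and run a forward pass. torch._scaled_mm needs an fp8-capable GPU (Ada or newer), so treat this as illustrative only:

import torch
import torch.nn as nn

from float8_quantize import F8Linear, recursive_swap_linears

lin = nn.Linear(3072, 3072, bias=True, device="cuda", dtype=torch.bfloat16)
f8 = F8Linear.from_linear(lin)  # weight is quantized to float8_e4m3fn with a per-tensor scale

x = torch.randn(1, 512, 3072, device="cuda", dtype=torch.bfloat16)
y = f8(x)  # input is scaled/cast to float8_e5m2, matmul runs through torch._scaled_mm

# Or swap every nn.Linear inside a module tree in place:
mlp = nn.Sequential(nn.Linear(3072, 12288), nn.GELU(), nn.Linear(12288, 3072))
recursive_swap_linears(mlp.cuda().bfloat16())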
flux_pipeline.py CHANGED
@@ -1,13 +1,12 @@
- import base64
  import io
  import math
  from typing import TYPE_CHECKING, Callable, List
  from PIL import Image
- from einops import rearrange, repeat
  import numpy as np

  import torch

+ from einops import rearrange
  from flux_emphasis import get_weighted_text_embeddings_flux

  torch.backends.cuda.matmul.allow_tf32 = True
@@ -20,10 +19,9 @@ from torch._inductor import config as ind_config
  from pybase64 import standard_b64decode

  config.cache_size_limit = 10000000000
- ind_config.force_fuse_int_mm_with_mul = True
-
+ ind_config.shape_padding = True
  from loguru import logger
- from turbojpeg_imgs import TurboImage
+ from image_encoder import ImageEncoder
  from torchvision.transforms import functional as TF
  from tqdm import tqdm
  from util import (
@@ -50,7 +48,7 @@ class FluxPipeline:
          t5: "HFEmbedder" = None,
          model: "Flux" = None,
          ae: "AutoEncoder" = None,
-         dtype: torch.dtype = torch.bfloat16,
+         dtype: torch.dtype = torch.float16,
          verbose: bool = False,
          flux_device: torch.device | str = "cuda:0",
          ae_device: torch.device | str = "cuda:1",
@@ -87,10 +85,42 @@ class FluxPipeline:
          self.model: "Flux" = model
          self.ae: "AutoEncoder" = ae
          self.rng = torch.Generator(device="cpu")
-         self.turbojpeg = TurboImage()
+         self.img_encoder = ImageEncoder()
          self.verbose = verbose
          self.ae_dtype = torch.bfloat16
          self.config = config
+         self.offload_text_encoder = config.offload_text_encoder
+         self.offload_vae = config.offload_vae
+         self.offload_flow = config.offload_flow
+
+         if self.config.compile_blocks or self.config.compile_extras:
+             print("Warmups for compile...")
+             warmup_dict = dict(
+                 prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
+                 height=1024,
+                 width=1024,
+                 num_steps=30,
+                 guidance=3.5,
+                 seed=10,
+             )
+             self.generate(**warmup_dict)
+             to_gpu_extras = [
+                 "vector_in",
+                 "img_in",
+                 "txt_in",
+                 "time_in",
+                 "guidance_in",
+                 "final_layer",
+                 "pe_embedder",
+             ]
+             if self.config.compile_blocks:
+                 for block in self.model.double_blocks:
+                     block.compile()
+                 for block in self.model.single_blocks:
+                     block.compile()
+             if self.config.compile_extras:
+                 for extra in to_gpu_extras:
+                     getattr(self.model, extra).compile()

      @torch.inference_mode()
      def prepare(
@@ -126,6 +156,9 @@ class FluxPipeline:
          )

          img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
+         if self.offload_text_encoder:
+             self.clip.to(self.device_clip)
+             self.t5.to(self.device_t5)
          vec, txt, txt_ids = get_weighted_text_embeddings_flux(
              self,
              prompt,
@@ -134,6 +167,10 @@ class FluxPipeline:
              target_device=target_device,
              target_dtype=target_dtype,
          )
+         if self.offload_text_encoder:
+             self.clip.to("cpu")
+             self.t5.to("cpu")
+             torch.cuda.empty_cache()
          return img, img_ids, vec, txt, txt_ids

      @torch.inference_mode()
@@ -196,29 +233,39 @@ class FluxPipeline:
      @torch.inference_mode()
      def into_bytes(self, x: torch.Tensor) -> io.BytesIO:
          # bring into PIL format and save
+         torch.cuda.synchronize()
+         x = x.contiguous()
          x = x.clamp(-1, 1)
          num_images = x.shape[0]
          images: List[torch.Tensor] = []
          for i in range(num_images):
-             x = x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
+             x = x[i].add(1.0).mul(127.5).clamp(0, 255).contiguous().type(torch.uint8)
              images.append(x)
          if len(images) == 1:
              im = images[0]
          else:
              im = torch.vstack(images)

-         im = self.turbojpeg.encode_torch(im, quality=95)
+         torch.cuda.synchronize()
+         im = self.turbojpeg.encode_torch(im, quality=99)
          images.clear()
          return io.BytesIO(im)

      @torch.inference_mode()
      def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
-         x = x.to(self.device_ae)
+         if self.offload_vae:
+             self.ae.to(self.device_ae)
+             x = x.to(self.device_ae)
+         else:
+             x = x.to(self.device_ae)
          x = self.unpack(x.float(), height, width)
          with torch.autocast(
              device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
          ):
              x = self.ae.decode(x)
+         if self.offload_vae:
+             self.ae.to("cpu")
+             torch.cuda.empty_cache()
          return x

      def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
@@ -269,11 +316,16 @@ class FluxPipeline:
              dtype=torch.bfloat16,
              cache_enabled=False,
          ):
+             if self.offload_vae:
+                 self.ae.to(self.device_ae)
              init_image = (
                  self.ae.encode(init_image)
                  .to(dtype=self.dtype, device=self.device_flux)
                  .repeat(num_images, 1, 1, 1)
              )
+             if self.offload_vae:
+                 self.ae.to("cpu")
+                 torch.cuda.empty_cache()

          x = self.get_noise(
              num_images,
@@ -338,11 +390,14 @@ class FluxPipeline:
              generator=generator,
              num_images=num_images,
          )
-         img, img_ids, vec, txt, txt_ids = self.prepare(
-             img=img,
-             prompt=prompt,
-             target_device=self.device_flux,
-             target_dtype=self.dtype,
+         img, img_ids, vec, txt, txt_ids = map(
+             lambda x: x.contiguous(),
+             self.prepare(
+                 img=img,
+                 prompt=prompt,
+                 target_device=self.device_flux,
+                 target_dtype=self.dtype,
+             ),
          )

          # this is ignored for schnell
@@ -350,6 +405,8 @@ class FluxPipeline:
              (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
          )
          t_vec = None
+         if self.offload_flow:
+             self.model.to(self.device_flux)
          for t_curr, t_prev in tqdm(
              zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
          ):
@@ -374,6 +431,8 @@

              img = img + (t_prev - t_curr) * pred

+         if self.offload_flow:
+             self.model.to("cpu")
          torch.cuda.empty_cache()

          # decode latents to pixel space
@@ -384,37 +443,35 @@ class FluxPipeline:
          return self.into_bytes(img)

      @classmethod
-     def load_pipeline_from_config_path(cls, path: str) -> "FluxPipeline":
+     def load_pipeline_from_config_path(
+         cls, path: str, flow_model_path: str = None
+     ) -> "FluxPipeline":
          with torch.inference_mode():
              config = load_config_from_path(path)
+             if flow_model_path:
+                 config.ckpt_path = flow_model_path
              return cls.load_pipeline_from_config(config)

      @classmethod
      def load_pipeline_from_config(cls, config: ModelSpec) -> "FluxPipeline":
-         from quantize_swap_and_dispatch import quantize_and_dispatch_to_device
+         from float8_quantize import quantize_flow_transformer_and_dispatch_float8

          with torch.inference_mode():
              print("flow_quantization_dtype", config.flow_quantization_dtype)

              models = load_models_from_config(config)
              config = models.config
-             num_layers_to_quantize = config.num_to_quant
              flux_device = into_device(config.flux_device)
              ae_device = into_device(config.ae_device)
              clip_device = into_device(config.text_enc_device)
              t5_device = into_device(config.text_enc_device)
              flux_dtype = into_dtype(config.flow_dtype)
-             flow_model = models.flow
-
-             flow_model = quantize_and_dispatch_to_device(
-                 flow_model=flow_model,
-                 flux_device=flux_device,
-                 flux_dtype=flux_dtype,
-                 num_layers_to_quantize=num_layers_to_quantize,
-                 compile_extras=config.compile_extras,
-                 compile_blocks=config.compile_blocks,
-                 quantize_extras=config.quantize_extras,
-                 quantization_dtype=config.flow_quantization_dtype,
+             flow_model = models.flow.type(flux_dtype).to(
+                 memory_format=torch.channels_last
+             )
+
+             flow_model = quantize_flow_transformer_and_dispatch_float8(
+                 flow_model, flux_device
              )

              return cls(
@@ -435,29 +492,24 @@

  if __name__ == "__main__":
      pipe = FluxPipeline.load_pipeline_from_config_path(
-         "configs/config-dev-gigaquant.json"
+         "configs/config-dev-offload.json"
      )
      o = pipe.generate(
          prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
          height=1024,
-         width=1024,
+         width=576,
          num_steps=24,
-         guidance=3.0,
+         guidance=3.5,
+         seed=10,
      )
      open("out.jpg", "wb").write(o.read())
-     o = pipe.generate(
-         prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
-         height=1024,
-         width=1024,
-         num_steps=24,
-         guidance=3.0,
-     )
-     open("out2.jpg", "wb").write(o.read())
-     o = pipe.generate(
-         prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
-         height=1024,
-         width=1024,
-         num_steps=24,
-         guidance=3.0,
-     )
-     open("out3.jpg", "wb").write(o.read())
+     for x in range(10):
+
+         o = pipe.generate(
+             prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
+             height=1024,
+             width=576,
+             num_steps=24,
+             guidance=3.5,
+         )
+         open(f"out{x}.jpg", "wb").write(o.read())
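The offload changes above all follow the same shape: move a module onto its device just before it is needed, then push it back to the CPU and drop the CUDA cache. A hypothetical helper (not in the commit) that captures that pattern:

from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def on_device(module: nn.Module, device: torch.device | str, offload: bool = True):
    """Keep the module on the given device only for the duration of the block."""
    if offload:
        module.to(device)
    try:
        yield module
    finally:
        if offload:
            module.to("cpu")
            torch.cuda.empty_cache()

# e.g. the VAE decode step in vae_decode() has this shape:
#   with on_device(self.ae, self.device_ae, self.offload_vae):
#       x = self.ae.decode(x)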
image_encoder.py ADDED
@@ -0,0 +1,71 @@
+ import io
+ from PIL import Image
+ import numpy as np
+ import torch
+
+
+ class ImageEncoder:
+
+     @torch.inference_mode()
+     def encode_torch(self, img: torch.Tensor, quality=90):
+         if img.ndim == 2:
+             img = (
+                 img[None]
+                 .contiguous()
+                 .repeat_interleave(3, dim=0)
+                 .contiguous()
+                 .clamp(0, 255)
+                 .type(torch.uint8)
+             )
+             print(img.shape)
+         elif img.ndim == 3:
+             if img.shape[0] == 3:
+                 img = img.contiguous().clamp(0, 255).type(torch.uint8)
+
+             elif img.shape[2] == 3:
+                 img = img.permute(2, 0, 1).contiguous().clamp(0, 255).type(torch.uint8)
+             else:
+                 raise ValueError(f"Unsupported image shape: {img.shape}")
+         else:
+             raise ValueError(f"Unsupported image num dims: {img.ndim}")
+
+         img = (
+             img.permute(1, 2, 0)
+             .contiguous()
+             .to(torch.uint8)
+             .cpu()
+             .numpy()
+             .astype(np.uint8)
+         )
+         im = Image.fromarray(img)
+         iob = io.BytesIO()
+         im.save(iob, format="JPEG", quality=95)
+         iob.seek(0)
+         return iob.getvalue()
+
+
+ def test_real_img():
+     from PIL import Image
+     import numpy as np
+
+     im = "out.jpg"
+     im = Image.open(im)
+     im = np.array(im)
+     img_hwc = torch.from_numpy(im).cuda().type(torch.float32)
+     img_chw = img_hwc.permute(2, 0, 1).contiguous()
+     img_gray = img_hwc.mean(dim=2, keepdim=False).contiguous().clamp(0, 255)
+     tj = TurboImage()
+     o = tj.encode_torch(img_chw)
+     o2 = tj.encode_torch(img_hwc)
+     o3 = tj.encode_torch(img_gray)
+     with open("out_chw.jpg", "wb") as f:
+         f.write(o2)
+     with open("out_hwc.jpg", "wb") as f:
+         f.write(o)
+     with open("out_gray.jpg", "wb") as f:
+         f.write(o3)
+     # print(o)
+
+
+ if __name__ == "__main__":
+     test_real_img()
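A tiny smoke test for the encoder above (mine, not part of the commit): encode a random CHW tensor and write the returned JPEG bytes to disk.

import torch

from image_encoder import ImageEncoder

img = torch.rand(3, 256, 256) * 255  # float CHW in [0, 255]; a CUDA tensor also works
jpeg_bytes = ImageEncoder().encode_torch(img, quality=95)
with open("check.jpg", "wb") as f:
    f.write(jpeg_bytes)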
main.py CHANGED
@@ -87,7 +87,9 @@ def main():
      args = parse_args()

      if args.config_path:
-         app.state.model = FluxPipeline.load_pipeline_from_config_path(args.config_path)
+         app.state.model = FluxPipeline.load_pipeline_from_config_path(
+             args.config_path, flow_model_path=args.flow_model_path
+         )
      else:
          model_version = (
              ModelVersion.flux_dev
modules/conditioner.py CHANGED
@@ -1,10 +1,6 @@
  import os

  import torch
- from pydash import max_
- from quanto import freeze, qfloat8, qint2, qint4, qint8, quantize
- from quanto.nn.qmodule import _QMODULE_TABLE
- from safetensors.torch import load_file, load_model, save_model
  from torch import Tensor, nn
  from transformers import (
      CLIPTextModel,
@@ -13,7 +9,7 @@ from transformers import (
      T5Tokenizer,
      __version__,
  )
- from transformers.utils.quantization_config import QuantoConfig
+ from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig

  CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")

@@ -31,6 +27,25 @@
      raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")


+ def auto_quantization_config(
+     quantization_dtype: str,
+ ) -> QuantoConfig | BitsAndBytesConfig:
+     if quantization_dtype == "qfloat8":
+         return QuantoConfig(weights="float8")
+     elif quantization_dtype == "qint4":
+         return BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_quant_type="nf4",
+         )
+     elif quantization_dtype == "qint8":
+         return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
+     elif quantization_dtype == "qint2":
+         return QuantoConfig(weights="int2")
+     else:
+         raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
+
+
  class HFEmbedder(nn.Module):
      def __init__(
          self,
@@ -38,15 +53,21 @@ class HFEmbedder(nn.Module):
          max_length: int,
          device: torch.device | int,
          quantization_dtype: str | None = None,
+         offloading_device: torch.device | int | None = torch.device("cpu"),
          **hf_kwargs,
      ):
          super().__init__()
+         self.offloading_device = (
+             offloading_device
+             if isinstance(offloading_device, torch.device)
+             else torch.device(offloading_device)
+         )
+         self.device = (
+             device if isinstance(device, torch.device) else torch.device(device)
+         )
          self.is_clip = version.startswith("openai")
          self.max_length = max_length
          self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
-         quant_name = (
-             into_quantization_name(quantization_dtype) if quantization_dtype else None
-         )

          if self.is_clip:
              self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
@@ -57,13 +78,10 @@ class HFEmbedder(nn.Module):
                  version,
                  **hf_kwargs,
                  quantization_config=(
-                     QuantoConfig(
-                         weights=quant_name,
-                     )
-                     if quant_name
+                     auto_quantization_config(quantization_dtype)
+                     if quantization_dtype
                      else None
                  ),
-                 device_map={"": device},
              )

          else:
@@ -72,17 +90,21 @@ class HFEmbedder(nn.Module):
              )
              self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                  version,
-                 device_map={"": device},
                  **hf_kwargs,
                  quantization_config=(
-                     QuantoConfig(
-                         weights=quant_name,
-                     )
-                     if quant_name
+                     auto_quantization_config(quantization_dtype)
+                     if quantization_dtype
                      else None
                  ),
              )

+     def offload(self):
+         self.hf_module.to(device=self.offloading_device)
+         torch.cuda.empty_cache()
+
+     def cuda(self):
+         self.hf_module.to(device=self.device)
+
      def forward(self, text: list[str]) -> Tensor:
          batch_encoding = self.tokenizer(
              text,
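A short sketch (not part of the commit) of how the new conditioner pieces are meant to be used together; the T5 repo id comes from the configs above, and "qint4" now resolves to a bitsandbytes NF4 config rather than quanto:

import torch

from modules.conditioner import HFEmbedder, auto_quantization_config

bnb_cfg = auto_quantization_config("qint4")  # BitsAndBytesConfig(load_in_4bit=True, nf4)

t5 = HFEmbedder(
    "city96/t5-v1_1-xxl-encoder-bf16",
    max_length=512,
    device=torch.device("cuda:0"),
    quantization_dtype="qint4",
)
t5.offload()  # hf_module -> offloading_device (CPU by default), then empty the CUDA cache
t5.cuda()     # hf_module -> the configured CUDA device before encoding prompts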
modules/flux_model.py CHANGED
@@ -11,14 +11,13 @@ torch.set_float32_matmul_precision("high")
  import math

  from torch import Tensor, nn
- from torch._dynamo import config
- from torch._inductor import config as ind_config
  from pydantic import BaseModel
  from torch.nn import functional as F

- config.cache_size_limit = 10000000000
- ind_config.compile_threads = os.cpu_count()
- ind_config.shape_padding = True
+ try:
+     from cublas_ops import CublasLinear
+ except ImportError:
+     CublasLinear = nn.Linear


  class FluxParams(BaseModel):
@@ -37,7 +36,7 @@ class FluxParams(BaseModel):


  # attention is always same shape each time it's called per H*W, so compile with fullgraph
- @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
+ # @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
      q, k = apply_rope(q, k, pe)
      x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
@@ -45,7 +44,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
      return x


- @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
+ # @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
  def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
      scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
      omega = 1.0 / (theta**scale)
@@ -202,8 +201,7 @@ class DoubleStreamBlock(nn.Module):
          num_heads: int,
          mlp_ratio: float,
          qkv_bias: bool = False,
-         dtype: torch.dtype = torch.bfloat16,
-         idx: int = 0,
+         dtype: torch.dtype = torch.float16,
      ):
          super().__init__()
          self.dtype = dtype
@@ -232,9 +230,9 @@

          self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
          self.txt_mlp = nn.Sequential(
-             (nn.Linear(hidden_size, mlp_hidden_dim, bias=True)),
+             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
              nn.GELU(approximate="tanh"),
-             (nn.Linear(mlp_hidden_dim, hidden_size, bias=True)),
+             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
          )
          self.K = 3
          self.H = self.num_heads
@@ -279,13 +277,13 @@
          img = img + img_mod1.gate * self.img_attn.proj(img_attn)
          img = img + img_mod2.gate * self.img_mlp(
              (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-         ).clamp(min=-384, max=384)
+         ).clamp(min=-384 * 2, max=384 * 2)

          # calculate the txt bloks
          txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
          txt = txt + txt_mod2.gate * self.txt_mlp(
              (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-         ).clamp(min=-384, max=384)
+         ).clamp(min=-384 * 2, max=384 * 2)

          return img, txt

@@ -302,7 +300,7 @@ class SingleStreamBlock(nn.Module):
          num_heads: int,
          mlp_ratio: float = 4.0,
          qk_scale: float | None = None,
-         dtype: torch.dtype = torch.bfloat16,
+         dtype: torch.dtype = torch.float16,
      ):
          super().__init__()
          self.dtype = dtype
@@ -343,7 +341,7 @@
          q, k = self.norm(q, k, v)
          attn = attention(q, k, v, pe=pe)
          output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
-             min=-384, max=384
+             min=-384 * 4, max=384 * 4
          )
          return x + mod.gate * output

@@ -352,11 +350,11 @@
      def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
          super().__init__()
          self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.linear = nn.Linear(
+         self.linear = CublasLinear(
              hidden_size, patch_size * patch_size * out_channels, bias=True
          )
          self.adaLN_modulation = nn.Sequential(
-             nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+             nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
          )

      def forward(self, x: Tensor, vec: Tensor) -> Tensor:
@@ -413,9 +411,8 @@ class Flux(nn.Module):
                      mlp_ratio=params.mlp_ratio,
                      qkv_bias=params.qkv_bias,
                      dtype=self.dtype,
-                     idx=idx,
                  )
-                 for idx in range(params.depth)
+                 for _ in range(params.depth)
              ]
          )

turbojpeg_imgs.py DELETED
@@ -1,134 +0,0 @@
- import numpy as np
- import torch
- from turbojpeg import (
-     TurboJPEG,
-     TJPF_GRAY,
-     TJFLAG_PROGRESSIVE,
-     TJFLAG_FASTUPSAMPLE,
-     TJFLAG_FASTDCT,
-     TJPF_RGB,
-     TJPF_BGR,
-     TJSAMP_GRAY,
-     TJSAMP_411,
-     TJSAMP_420,
-     TJSAMP_422,
-     TJSAMP_444,
-     TJSAMP_440,
-     TJSAMP_441,
- )
-
-
- class Subsampling:
-     S411 = TJSAMP_411
-     S420 = TJSAMP_420
-     S422 = TJSAMP_422
-     S444 = TJSAMP_444
-     S440 = TJSAMP_440
-     S441 = TJSAMP_441
-     GRAY = TJSAMP_GRAY
-
-
- class Flags:
-     PROGRESSIVE = TJFLAG_PROGRESSIVE
-     FASTUPSAMPLE = TJFLAG_FASTUPSAMPLE
-     FASTDCT = TJFLAG_FASTDCT
-
-
- class PixelFormat:
-     GRAY = TJPF_GRAY
-     RGB = TJPF_RGB
-     BGR = TJPF_BGR
-
-
- class TurboImage:
-     def __init__(self):
-         self.tj = TurboJPEG()
-         self.flags = Flags.PROGRESSIVE
-
-         self.subsampling_gray = Subsampling.GRAY
-         self.pixel_format_gray = PixelFormat.GRAY
-         self.subsampling_rgb = Subsampling.S420
-         self.pixel_format_rgb = PixelFormat.RGB
-
-     def set_subsampling_gray(self, subsampling):
-         self.subsampling_gray = subsampling
-
-     def set_subsampling_rgb(self, subsampling):
-         self.subsampling_rgb = subsampling
-
-     def set_pixel_format_gray(self, pixel_format):
-         self.pixel_format_gray = pixel_format
-
-     def set_pixel_format_rgb(self, pixel_format):
-         self.pixel_format_rgb = pixel_format
-
-     def set_flags(self, flags):
-         self.flags = flags
-
-     def encode(
-         self,
-         img,
-         subsampling,
-         pixel_format,
-         quality=90,
-     ):
-         return self.tj.encode(
-             img,
-             quality=quality,
-             flags=self.flags,
-             pixel_format=pixel_format,
-             jpeg_subsample=subsampling,
-         )
-
-     @torch.inference_mode()
-     def encode_torch(self, img: torch.Tensor, quality=90):
-         if img.ndim == 2:
-             subsampling = self.subsampling_gray
-             pixel_format = self.pixel_format_gray
-             img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8)
-         elif img.ndim == 3:
-             subsampling = self.subsampling_rgb
-             pixel_format = self.pixel_format_rgb
-             if img.shape[0] == 3:
-                 img = (
-                     img.permute(1, 2, 0)
-                     .clamp(0, 255)
-                     .cpu()
-                     .contiguous()
-                     .numpy()
-                     .astype(np.uint8)
-                 )
-             elif img.shape[2] == 3:
-                 img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8)
-             else:
-                 raise ValueError(f"Unsupported image shape: {img.shape}")
-         else:
-             raise ValueError(f"Unsupported image num dims: {img.ndim}")
-
-         return self.encode(
-             img,
-             quality=quality,
-             subsampling=subsampling,
-             pixel_format=pixel_format,
-         )
-
-     def encode_numpy(self, img: np.ndarray, quality=90):
-         if img.ndim == 2:
-             subsampling = self.subsampling_gray
-             pixel_format = self.pixel_format_gray
-         elif img.ndim == 3:
-             if img.shape[0] == 3:
-                 img = np.ascontiguousarray(img.transpose(1, 2, 0))
-             elif img.shape[2] == 3:
-                 img = np.ascontiguousarray(img)
-             else:
-                 raise ValueError(f"Unsupported image shape: {img.shape}")
-             subsampling = self.subsampling_rgb
-             pixel_format = self.pixel_format_rgb
-         else:
-             raise ValueError(f"Unsupported image num dims: {img.ndim}")
-
-         img = img.clip(0, 255).astype(np.uint8)
-         return self.encode(
-             img, quality=quality, subsampling=subsampling, pixel_format=pixel_format
-         )
util.py CHANGED
@@ -50,6 +50,9 @@ class ModelSpec(BaseModel):
      text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
      ae_quantization_dtype: Optional[QuantizationDtype] = None
      clip_quantization_dtype: Optional[QuantizationDtype] = None
+     offload_text_encoder: bool = False
+     offload_vae: bool = False
+     offload_flow: bool = False

      model_config: ConfigDict = {
          "arbitrary_types_allowed": True,
@@ -242,6 +245,9 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
          current_quants=0,
          quantization_dtype=into_qtype(config.ae_quantization_dtype),
      )
+     if config.offload_vae:
+         ae.to("cpu")
+         torch.cuda.empty_cache()
      return ae
