Safetensors
aredden committed
Commit c4a514f · 1 Parent(s): c2ecfb5
configs/config-dev-gigaquant.json CHANGED
@@ -41,12 +41,18 @@
   "repo_ae": "ae.sft",
   "text_enc_max_length": 512,
   "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
-  "text_enc_device": "cuda:1",
-  "ae_device": "cuda:1",
+  "text_enc_device": "cuda:0",
+  "ae_device": "cuda:0",
   "flux_device": "cuda:0",
   "flow_dtype": "float16",
   "ae_dtype": "bfloat16",
   "text_enc_dtype": "bfloat16",
-  "num_to_quant": 8000,
+  "num_to_quant": 220,
+  "flow_quantization_dtype": "qint4",
+  "text_enc_quantization_dtype": "qint4",
+  "ae_quantization_dtype": "qint4",
+  "clip_quantization_dtype": "qint4",
+  "compile_extras": false,
+  "compile_blocks": false,
   "quantize_extras": true
 }
configs/config-dev.json CHANGED
@@ -47,6 +47,8 @@
   "flow_dtype": "float16",
   "ae_dtype": "bfloat16",
   "text_enc_dtype": "bfloat16",
+  "flow_quantization_dtype": "qfloat8",
+  "text_enc_quantization_dtype": "qfloat8",
   "num_to_quant": 22,
   "compile_extras": false,
   "compile_blocks": false
flux_pipeline.py CHANGED
@@ -394,6 +394,7 @@ class FluxPipeline:
         from quantize_swap_and_dispatch import quantize_and_dispatch_to_device
 
         with torch.inference_mode():
+            print("flow_quantization_dtype", config.flow_quantization_dtype)
 
             models = load_models_from_config(config)
             config = models.config
@@ -413,6 +414,7 @@
             compile_extras=config.compile_extras,
             compile_blocks=config.compile_blocks,
             quantize_extras=config.quantize_extras,
+            quantization_dtype=config.flow_quantization_dtype,
         )
 
         return cls(
modules/conditioner.py CHANGED
@@ -1,37 +1,85 @@
-from torch import Tensor, nn
+import os
+
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from pydash import max_
+from quanto import freeze, qfloat8, qint2, qint4, qint8, quantize
+from quanto.nn.qmodule import _QMODULE_TABLE
+from safetensors.torch import load_file, load_model, save_model
+from torch import Tensor, nn
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5Tokenizer,
+    __version__,
+)
+from transformers.utils.quantization_config import QuantoConfig
+
+CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")
 
-from transformers.utils.quantization_config import BitsAndBytesConfig, QuantoConfig
+
+def into_quantization_name(quantization_dtype: str) -> str:
+    if quantization_dtype == "qfloat8":
+        return "float8"
+    elif quantization_dtype == "qint4":
+        return "int4"
+    elif quantization_dtype == "qint8":
+        return "int8"
+    elif quantization_dtype == "qint2":
+        return "int2"
+    else:
+        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
 
 
 class HFEmbedder(nn.Module):
     def __init__(
-        self, version: str, max_length: int, device: torch.device | int, **hf_kwargs
+        self,
+        version: str,
+        max_length: int,
+        device: torch.device | int,
+        quantization_dtype: str | None = None,
+        **hf_kwargs,
     ):
         super().__init__()
         self.is_clip = version.startswith("openai")
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+        quant_name = (
+            into_quantization_name(quantization_dtype) if quantization_dtype else None
+        )
 
         if self.is_clip:
             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                 version, max_length=max_length
             )
+
             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
-                version, **hf_kwargs
+                version,
+                **hf_kwargs,
+                quantization_config=(
+                    QuantoConfig(
+                        weights=quant_name,
+                    )
+                    if quant_name
+                    else None
+                ),
+                device_map={"": device},
             )
-            self.hf_module = self.hf_module.eval().requires_grad_(False).to(device)
+
         else:
             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                 version, max_length=max_length
             )
             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                 version,
-                **hf_kwargs,
                 device_map={"": device},
-                quantization_config=QuantoConfig(
-                    weights="float8",
+                **hf_kwargs,
+                quantization_config=(
+                    QuantoConfig(
+                        weights=quant_name,
+                    )
+                    if quant_name
+                    else None
                 ),
             )
 
@@ -51,3 +99,14 @@ class HFEmbedder(nn.Module):
             output_hidden_states=False,
         )
         return outputs[self.output_key]
+
+
+if __name__ == "__main__":
+    model = HFEmbedder(
+        "city96/t5-v1_1-xxl-encoder-bf16",
+        max_length=512,
+        device=0,
+        quantization_dtype="qfloat8",
+    )
+    o = model(["hello"])
+    print(o)
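
The `into_quantization_name` helper exists because the config uses quanto-style names ("qfloat8", "qint4") while transformers' `QuantoConfig` expects the bare weight names ("float8", "int4"). A small sketch of that mapping, assuming `modules/conditioner.py` is importable as laid out in this repo:

```python
# Minimal sketch of the string mapping added above (module path is an assumption).
from modules.conditioner import into_quantization_name
from transformers.utils.quantization_config import QuantoConfig

quant_config = QuantoConfig(weights=into_quantization_name("qint4"))  # weights="int4"
```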
modules/flux_model.py CHANGED
@@ -13,7 +13,6 @@ import math
 from torch import Tensor, nn
 from torch._dynamo import config
 from torch._inductor import config as ind_config
-from xformers.ops import memory_efficient_attention_forward
 from pydantic import BaseModel
 from torch.nn import functional as F
 
quantize_swap_and_dispatch.py CHANGED
@@ -5,11 +5,35 @@ import torch
 from click import secho
 from cublas_ops import CublasLinear
 
-from quanto.nn import QModuleMixin, quantize_module, QLinear, QConv2d, QLayerNorm
-from quanto.tensor import Optimizer, qtype, qfloat8
+from quanto import (
+    QModuleMixin,
+    quantize_module,
+    QLinear,
+    QConv2d,
+    QLayerNorm,
+)
+from quanto.tensor import Optimizer, qtype, qfloat8, qint4, qint8
 from torch import nn
 
 
+class QuantizationDtype:
+    qfloat8 = "qfloat8"
+    qint2 = "qint2"
+    qint4 = "qint4"
+    qint8 = "qint8"
+
+
+def into_qtype(qtype: QuantizationDtype) -> qtype:
+    if qtype == QuantizationDtype.qfloat8:
+        return qfloat8
+    elif qtype == QuantizationDtype.qint4:
+        return qint4
+    elif qtype == QuantizationDtype.qint8:
+        return qint8
+    else:
+        raise ValueError(f"Unknown qtype: {qtype}")
+
+
 def _set_module_by_name(parent_module, name, child_module):
     module_names = name.split(".")
     if len(module_names) == 1:
@@ -121,7 +145,12 @@ def _is_block_compilable(module: nn.Module) -> bool:
 
 def _simple_swap_linears(model: nn.Module, root_name: str = ""):
     for name, module in model.named_children():
-        if _is_linear(module):
+        if (
+            _is_linear(module)
+            and hasattr(module, "weight")
+            and module.weight is not None
+            and module.weight.data is not None
+        ):
             weights = module.weight.data
             bias = None
             if module.bias is not None:
@@ -155,7 +184,7 @@ def _full_quant(
     if current_quants < max_quants:
         current_quants += _quantize(model, quantization_dtype)
         _freeze(model)
-        print(f"Quantized {current_quants} modules")
+        print(f"Quantized {current_quants} modules with {quantization_dtype}")
     return current_quants
 
 
@@ -174,11 +203,13 @@ def quantize_and_dispatch_to_device(
     flux_device: torch.device = torch.device("cuda"),
     flux_dtype: torch.dtype = torch.float16,
     num_layers_to_quantize: int = 20,
-    quantization_dtype: qtype = qfloat8,
+    quantization_dtype: QuantizationDtype = QuantizationDtype.qfloat8,
     compile_blocks: bool = True,
     compile_extras: bool = True,
     quantize_extras: bool = False,
+    replace_linears: bool = True,
 ):
+    quant_type = into_qtype(quantization_dtype)
     num_quanted = 0
     flow_model = flow_model.requires_grad_(False).eval().type(flux_dtype)
     for block in flow_model.single_blocks:
@@ -188,7 +219,7 @@
             block,
             num_layers_to_quantize,
             num_quanted,
-            quantization_dtype=quantization_dtype,
+            quantization_dtype=quant_type,
         )
 
     for block in flow_model.double_blocks:
@@ -198,7 +229,7 @@
             block,
             num_layers_to_quantize,
             num_quanted,
-            quantization_dtype=quantization_dtype,
+            quantization_dtype=quant_type,
         )
 
     to_gpu_extras = [
@@ -221,10 +252,11 @@
             block.compile()
             secho(f"Compiled block {i}", fg="green")
 
-    _simple_swap_linears(flow_model)
+    if replace_linears:
+        _simple_swap_linears(flow_model)
     for extra in to_gpu_extras:
         m_extra = getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
-        if compile_blocks:
+        if compile_extras:
             if extra in ["time_in", "vector_in", "guidance_in", "final_layer"]:
                 m_extra.compile()
                 secho(
@@ -232,10 +264,11 @@
                     fg="green",
                 )
         elif quantize_extras:
-            _full_quant(
-                m_extra,
-                current_quants=num_quanted,
-                max_quants=num_layers_to_quantize,
-                quantization_dtype=quantization_dtype,
-            )
+            if not isinstance(m_extra, nn.Linear):
+                _full_quant(
+                    m_extra,
+                    current_quants=num_quanted,
+                    max_quants=num_layers_to_quantize,
+                    quantization_dtype=quantization_dtype,
+                )
     return flow_model
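
With this change the entry point takes the plain string dtype from the config and converts it internally via `into_qtype`. A hedged call sketch, assuming `flow_model` was already built by `load_models_from_config` and that it is the first positional parameter (not shown in the hunks above):

```python
# Hypothetical usage of the updated signature; argument values are examples only.
import torch

from quantize_swap_and_dispatch import (
    QuantizationDtype,
    quantize_and_dispatch_to_device,
)

flow_model = quantize_and_dispatch_to_device(
    flow_model,                                    # assumed to exist already
    flux_device=torch.device("cuda:0"),
    flux_dtype=torch.float16,
    num_layers_to_quantize=22,
    quantization_dtype=QuantizationDtype.qfloat8,  # a plain string under the hood
    compile_blocks=False,
    compile_extras=False,
    quantize_extras=True,
    replace_linears=True,
)
```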
util.py CHANGED
@@ -18,6 +18,13 @@ class ModelVersion(StrEnum):
     flux_schnell = "flux-schnell"
 
 
+class QuantizationDtype(StrEnum):
+    qfloat8 = "qfloat8"
+    qint2 = "qint2"
+    qint4 = "qint4"
+    qint8 = "qint8"
+
+
 class ModelSpec(BaseModel):
     version: ModelVersion
     params: FluxParams
@@ -39,6 +46,10 @@
     quantize_extras: bool = False
     compile_extras: bool = False
     compile_blocks: bool = False
+    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
+    text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
+    ae_quantization_dtype: Optional[QuantizationDtype] = None
+    clip_quantization_dtype: Optional[QuantizationDtype] = None
 
     model_config: ConfigDict = {
         "arbitrary_types_allowed": True,
@@ -199,13 +210,15 @@ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
         "openai/clip-vit-large-patch14",
         max_length=77,
         torch_dtype=into_dtype(config.text_enc_dtype),
-        device=into_device(config.text_enc_device),
+        device=into_device(config.text_enc_device).index or 0,
+        quantization_dtype=config.clip_quantization_dtype,
     )
     t5 = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=into_dtype(config.text_enc_dtype),
        device=into_device(config.text_enc_device).index or 0,
+       quantization_dtype=config.text_enc_quantization_dtype,
    )
    return clip, t5
 
@@ -213,12 +226,22 @@ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
 def load_autoencoder(config: ModelSpec) -> AutoEncoder:
     ckpt_path = config.ae_path
     with torch.device("meta" if ckpt_path is not None else config.ae_device):
-        ae = AutoEncoder(config.ae_params)
+        ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype))
 
     if ckpt_path is not None:
         sd = load_sft(ckpt_path, device=str(config.ae_device))
         missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
         print_load_warning(missing, unexpected)
+    if config.ae_quantization_dtype is not None:
+        from quantize_swap_and_dispatch import _full_quant, into_qtype
+
+        ae.to(into_device(config.ae_device))
+        _full_quant(
+            ae,
+            max_quants=8000,
+            current_quants=0,
+            quantization_dtype=into_qtype(config.ae_quantization_dtype),
+        )
     return ae
 
 