Updates

- configs/config-dev-gigaquant.json +9 -3
- configs/config-dev.json +2 -0
- flux_pipeline.py +2 -0
- modules/conditioner.py +68 -9
- modules/flux_model.py +0 -1
- quantize_swap_and_dispatch.py +48 -15
- util.py +25 -2
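These changes thread per-model quantization dtypes (qfloat8 / qint8 / qint4 / qint2) from the JSON configs through HFEmbedder, quantize_and_dispatch_to_device, and load_autoencoder, so the flow model, text encoders, CLIP, and autoencoder can each be quantized independently.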
configs/config-dev-gigaquant.json CHANGED
@@ -41,12 +41,18 @@
     "repo_ae": "ae.sft",
     "text_enc_max_length": 512,
     "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
-    "text_enc_device": "cuda:
-    "ae_device": "cuda:
+    "text_enc_device": "cuda:0",
+    "ae_device": "cuda:0",
     "flux_device": "cuda:0",
     "flow_dtype": "float16",
     "ae_dtype": "bfloat16",
     "text_enc_dtype": "bfloat16",
-    "num_to_quant":
+    "num_to_quant": 220,
+    "flow_quantization_dtype": "qint4",
+    "text_enc_quantization_dtype": "qint4",
+    "ae_quantization_dtype": "qint4",
+    "clip_quantization_dtype": "qint4",
+    "compile_extras": false,
+    "compile_blocks": false,
     "quantize_extras": true
 }
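With num_to_quant raised to 220 and every *_quantization_dtype set to qint4, this config quantizes essentially the entire pipeline to int4 (presumably the reason for the "gigaquant" name).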
configs/config-dev.json CHANGED
@@ -47,6 +47,8 @@
     "flow_dtype": "float16",
     "ae_dtype": "bfloat16",
     "text_enc_dtype": "bfloat16",
+    "flow_quantization_dtype": "qfloat8",
+    "text_enc_quantization_dtype": "qfloat8",
     "num_to_quant": 22,
     "compile_extras": false,
     "compile_blocks": false
flux_pipeline.py CHANGED
@@ -394,6 +394,7 @@ class FluxPipeline:
         from quantize_swap_and_dispatch import quantize_and_dispatch_to_device

         with torch.inference_mode():
+            print("flow_quantization_dtype", config.flow_quantization_dtype)

             models = load_models_from_config(config)
             config = models.config
@@ -413,6 +414,7 @@ class FluxPipeline:
                 compile_extras=config.compile_extras,
                 compile_blocks=config.compile_blocks,
                 quantize_extras=config.quantize_extras,
+                quantization_dtype=config.flow_quantization_dtype,
             )

             return cls(
modules/conditioner.py CHANGED
@@ -1,37 +1,85 @@
-
+import os
+
 import torch
-from
+from pydash import max_
+from quanto import freeze, qfloat8, qint2, qint4, qint8, quantize
+from quanto.nn.qmodule import _QMODULE_TABLE
+from safetensors.torch import load_file, load_model, save_model
+from torch import Tensor, nn
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5Tokenizer,
+    __version__,
+)
+from transformers.utils.quantization_config import QuantoConfig

-
+CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")

+
+def into_quantization_name(quantization_dtype: str) -> str:
+    if quantization_dtype == "qfloat8":
+        return "float8"
+    elif quantization_dtype == "qint4":
+        return "int4"
+    elif quantization_dtype == "qint8":
+        return "int8"
+    elif quantization_dtype == "qint2":
+        return "int2"
+    else:
+        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")


 class HFEmbedder(nn.Module):
     def __init__(
-        self,
+        self,
+        version: str,
+        max_length: int,
+        device: torch.device | int,
+        quantization_dtype: str | None = None,
+        **hf_kwargs,
     ):
         super().__init__()
         self.is_clip = version.startswith("openai")
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+        quant_name = (
+            into_quantization_name(quantization_dtype) if quantization_dtype else None
+        )

         if self.is_clip:
             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                 version, max_length=max_length
             )
+
             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
-                version,
+                version,
+                **hf_kwargs,
+                quantization_config=(
+                    QuantoConfig(
+                        weights=quant_name,
+                    )
+                    if quant_name
+                    else None
+                ),
+                device_map={"": device},
             )
-
+
         else:
             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                 version, max_length=max_length
             )
             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                 version,
-                **hf_kwargs,
                 device_map={"": device},
-
-
+                **hf_kwargs,
+                quantization_config=(
+                    QuantoConfig(
+                        weights=quant_name,
+                    )
+                    if quant_name
+                    else None
                 ),
             )

@@ -51,3 +99,14 @@ class HFEmbedder(nn.Module):
             output_hidden_states=False,
         )
         return outputs[self.output_key]
+
+
+if __name__ == "__main__":
+    model = HFEmbedder(
+        "city96/t5-v1_1-xxl-encoder-bf16",
+        max_length=512,
+        device=0,
+        quantization_dtype="qfloat8",
+    )
+    o = model(["hello"])
+    print(o)
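One consequence of this rewrite is that quantization now happens at load time through transformers' QuantoConfig rather than after the weights are materialized. A minimal sketch of the equivalent direct call (model name taken from the diff; "float8" is what into_quantization_name returns for "qfloat8"; assumes a transformers build with quanto support):

import torch
from transformers import T5EncoderModel
from transformers.utils.quantization_config import QuantoConfig

# Load the T5 encoder with quanto float8 weights applied during load,
# mirroring HFEmbedder(..., quantization_dtype="qfloat8").
t5 = T5EncoderModel.from_pretrained(
    "city96/t5-v1_1-xxl-encoder-bf16",
    torch_dtype=torch.bfloat16,
    quantization_config=QuantoConfig(weights="float8"),
    device_map={"": 0},
)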
modules/flux_model.py CHANGED
@@ -13,7 +13,6 @@ import math
 from torch import Tensor, nn
 from torch._dynamo import config
 from torch._inductor import config as ind_config
-from xformers.ops import memory_efficient_attention_forward
 from pydantic import BaseModel
 from torch.nn import functional as F

quantize_swap_and_dispatch.py CHANGED
@@ -5,11 +5,35 @@ import torch
 from click import secho
 from cublas_ops import CublasLinear

-from quanto
-
+from quanto import (
+    QModuleMixin,
+    quantize_module,
+    QLinear,
+    QConv2d,
+    QLayerNorm,
+)
+from quanto.tensor import Optimizer, qtype, qfloat8, qint4, qint8
 from torch import nn


+class QuantizationDtype:
+    qfloat8 = "qfloat8"
+    qint2 = "qint2"
+    qint4 = "qint4"
+    qint8 = "qint8"
+
+
+def into_qtype(qtype: QuantizationDtype) -> qtype:
+    if qtype == QuantizationDtype.qfloat8:
+        return qfloat8
+    elif qtype == QuantizationDtype.qint4:
+        return qint4
+    elif qtype == QuantizationDtype.qint8:
+        return qint8
+    else:
+        raise ValueError(f"Unknown qtype: {qtype}")
+
+
 def _set_module_by_name(parent_module, name, child_module):
     module_names = name.split(".")
     if len(module_names) == 1:
@@ -121,7 +145,12 @@ def _is_block_compilable(module: nn.Module) -> bool:

 def _simple_swap_linears(model: nn.Module, root_name: str = ""):
     for name, module in model.named_children():
-        if
+        if (
+            _is_linear(module)
+            and hasattr(module, "weight")
+            and module.weight is not None
+            and module.weight.data is not None
+        ):
             weights = module.weight.data
             bias = None
             if module.bias is not None:
@@ -155,7 +184,7 @@ def _full_quant(
     if current_quants < max_quants:
         current_quants += _quantize(model, quantization_dtype)
         _freeze(model)
-        print(f"Quantized {current_quants} modules")
+        print(f"Quantized {current_quants} modules with {quantization_dtype}")
     return current_quants


@@ -174,11 +203,13 @@ def quantize_and_dispatch_to_device(
     flux_device: torch.device = torch.device("cuda"),
     flux_dtype: torch.dtype = torch.float16,
     num_layers_to_quantize: int = 20,
-    quantization_dtype:
+    quantization_dtype: QuantizationDtype = QuantizationDtype.qfloat8,
     compile_blocks: bool = True,
     compile_extras: bool = True,
     quantize_extras: bool = False,
+    replace_linears: bool = True,
 ):
+    quant_type = into_qtype(quantization_dtype)
     num_quanted = 0
     flow_model = flow_model.requires_grad_(False).eval().type(flux_dtype)
     for block in flow_model.single_blocks:
@@ -188,7 +219,7 @@
             block,
             num_layers_to_quantize,
             num_quanted,
-            quantization_dtype=
+            quantization_dtype=quant_type,
         )

     for block in flow_model.double_blocks:
@@ -198,7 +229,7 @@
             block,
             num_layers_to_quantize,
             num_quanted,
-            quantization_dtype=
+            quantization_dtype=quant_type,
         )

     to_gpu_extras = [
@@ -221,10 +252,11 @@
             block.compile()
             secho(f"Compiled block {i}", fg="green")

-
+    if replace_linears:
+        _simple_swap_linears(flow_model)
     for extra in to_gpu_extras:
         m_extra = getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
-        if
+        if compile_extras:
             if extra in ["time_in", "vector_in", "guidance_in", "final_layer"]:
                 m_extra.compile()
                 secho(
@@ -232,10 +264,11 @@
                     fg="green",
                 )
             elif quantize_extras:
-
-
-
-
-
-
+                if not isinstance(m_extra, nn.Linear):
+                    _full_quant(
+                        m_extra,
+                        current_quants=num_quanted,
+                        max_quants=num_layers_to_quantize,
+                        quantization_dtype=quantization_dtype,
+                    )
     return flow_model
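Two details worth flagging here: this module defines its own plain QuantizationDtype class whose members are the same strings as util.py's StrEnum, so either spelling reaches into_qtype intact; and into_qtype has no branch for qint2, so passing "qint2" raises ValueError even though the constant is declared. A quick sketch of the mapping (assumes quanto is installed and this module imports):

from quantize_swap_and_dispatch import QuantizationDtype, into_qtype

# String constants resolve to quanto's qtype objects used by _quantize/_full_quant.
qt = into_qtype(QuantizationDtype.qint4)
print(qt)

try:
    into_qtype(QuantizationDtype.qint2)
except ValueError as e:
    print(e)  # "Unknown qtype: qint2" -- qint2 is declared but never mapped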
util.py CHANGED
@@ -18,6 +18,13 @@ class ModelVersion(StrEnum):
     flux_schnell = "flux-schnell"


+class QuantizationDtype(StrEnum):
+    qfloat8 = "qfloat8"
+    qint2 = "qint2"
+    qint4 = "qint4"
+    qint8 = "qint8"
+
+
 class ModelSpec(BaseModel):
     version: ModelVersion
     params: FluxParams
@@ -39,6 +46,10 @@ class ModelSpec(BaseModel):
     quantize_extras: bool = False
     compile_extras: bool = False
     compile_blocks: bool = False
+    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
+    text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
+    ae_quantization_dtype: Optional[QuantizationDtype] = None
+    clip_quantization_dtype: Optional[QuantizationDtype] = None

     model_config: ConfigDict = {
         "arbitrary_types_allowed": True,
@@ -199,13 +210,15 @@ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
         "openai/clip-vit-large-patch14",
         max_length=77,
         torch_dtype=into_dtype(config.text_enc_dtype),
-        device=into_device(config.text_enc_device),
+        device=into_device(config.text_enc_device).index or 0,
+        quantization_dtype=config.clip_quantization_dtype,
     )
     t5 = HFEmbedder(
         config.text_enc_path,
         max_length=config.text_enc_max_length,
         torch_dtype=into_dtype(config.text_enc_dtype),
         device=into_device(config.text_enc_device).index or 0,
+        quantization_dtype=config.text_enc_quantization_dtype,
     )
     return clip, t5

@@ -213,12 +226,22 @@ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
 def load_autoencoder(config: ModelSpec) -> AutoEncoder:
     ckpt_path = config.ae_path
     with torch.device("meta" if ckpt_path is not None else config.ae_device):
-        ae = AutoEncoder(config.ae_params)
+        ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype))

     if ckpt_path is not None:
         sd = load_sft(ckpt_path, device=str(config.ae_device))
         missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
         print_load_warning(missing, unexpected)
+    if config.ae_quantization_dtype is not None:
+        from quantize_swap_and_dispatch import _full_quant, into_qtype
+
+        ae.to(into_device(config.ae_device))
+        _full_quant(
+            ae,
+            max_quants=8000,
+            current_quants=0,
+            quantization_dtype=into_qtype(config.ae_quantization_dtype),
+        )
     return ae

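Because QuantizationDtype here is a StrEnum, the bare strings in the JSON configs coerce directly into the enum when pydantic validates ModelSpec. A small sketch of that round-trip (field defaults as in the diff above):

from util import QuantizationDtype

# Plain config strings resolve to enum members.
assert QuantizationDtype("qint4") is QuantizationDtype.qint4

# Per the ModelSpec defaults: flow and text encoder quantize to qfloat8
# unless overridden, while ae/clip default to None (unquantized).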