Updates

Changed files:

- configs/config-dev-gigaquant.json +9 -3
- configs/config-dev.json +2 -0
- flux_pipeline.py +2 -0
- modules/conditioner.py +68 -9
- modules/flux_model.py +0 -1
- quantize_swap_and_dispatch.py +48 -15
- util.py +25 -2
configs/config-dev-gigaquant.json

```diff
@@ -41,12 +41,18 @@
     "repo_ae": "ae.sft",
     "text_enc_max_length": 512,
     "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
-    "text_enc_device": "cuda:…",
-    "ae_device": "cuda:…",
+    "text_enc_device": "cuda:0",
+    "ae_device": "cuda:0",
     "flux_device": "cuda:0",
     "flow_dtype": "float16",
     "ae_dtype": "bfloat16",
     "text_enc_dtype": "bfloat16",
-    "num_to_quant": …,
+    "num_to_quant": 220,
+    "flow_quantization_dtype": "qint4",
+    "text_enc_quantization_dtype": "qint4",
+    "ae_quantization_dtype": "qint4",
+    "clip_quantization_dtype": "qint4",
+    "compile_extras": false,
+    "compile_blocks": false,
     "quantize_extras": true
 }
```
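As a quick sanity check, the new per-component keys can be read back with the standard library alone (a minimal sketch; the path assumes you run it from the repo root):

```python
# Minimal sketch: inspect the per-component quantization keys added above.
import json

with open("configs/config-dev-gigaquant.json") as f:
    cfg = json.load(f)

# This config points every component at "qint4" and raises "num_to_quant"
# to 220, so far more layers get quantized than in the dev config.
for key in (
    "flow_quantization_dtype",
    "text_enc_quantization_dtype",
    "ae_quantization_dtype",
    "clip_quantization_dtype",
):
    print(key, "=", cfg[key])
```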
    	
configs/config-dev.json

```diff
@@ -47,6 +47,8 @@
     "flow_dtype": "float16",
     "ae_dtype": "bfloat16",
     "text_enc_dtype": "bfloat16",
+    "flow_quantization_dtype": "qfloat8",
+    "text_enc_quantization_dtype": "qfloat8",
     "num_to_quant": 22,
     "compile_extras": false,
     "compile_blocks": false
```
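The dev config only opts the flow model and the T5 encoder into the lighter qfloat8 weights. What a `*_quantization_dtype` value ultimately triggers is a quanto quantize/freeze pass; here is a self-contained sketch using the same `quanto` imports this repo uses elsewhere (the toy model is illustrative, not from the repo):

```python
# Sketch of what selecting a quantization dtype boils down to: quanto swaps
# eligible layers for quantized equivalents, and freeze() materializes the
# quantized weights in place of the float ones.
from quanto import freeze, qfloat8, quantize
from torch import nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
quantize(model, weights=qfloat8)  # "qfloat8" from the config maps here
freeze(model)
print(model)  # the Linear layers are now quanto QLinear modules
```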
    	
flux_pipeline.py

```diff
@@ -394,6 +394,7 @@ class FluxPipeline:
         from quantize_swap_and_dispatch import quantize_and_dispatch_to_device
 
         with torch.inference_mode():
+            print("flow_quantization_dtype", config.flow_quantization_dtype)
 
             models = load_models_from_config(config)
             config = models.config
@@ -413,6 +414,7 @@ class FluxPipeline:
                 compile_extras=config.compile_extras,
                 compile_blocks=config.compile_blocks,
                 quantize_extras=config.quantize_extras,
+                quantization_dtype=config.flow_quantization_dtype,
             )
 
         return cls(
```
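Note that loading still happens under `torch.inference_mode()`, so everything built there is exempt from autograd tracking; a two-line illustration:

```python
# Tensors created under inference_mode never record gradient state, which
# is the point of wrapping the model setup above in it.
import torch

with torch.inference_mode():
    t = torch.ones(3) * 2

print(t.is_inference(), t.requires_grad)  # True False
```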
    	
modules/conditioner.py

```diff
@@ -1,37 +1,85 @@
-…
+import os
+
 import torch
-from …
+from pydash import max_
+from quanto import freeze, qfloat8, qint2, qint4, qint8, quantize
+from quanto.nn.qmodule import _QMODULE_TABLE
+from safetensors.torch import load_file, load_model, save_model
+from torch import Tensor, nn
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5Tokenizer,
+    __version__,
+)
+from transformers.utils.quantization_config import QuantoConfig
+
+CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")
 
-…
+
+def into_quantization_name(quantization_dtype: str) -> str:
+    if quantization_dtype == "qfloat8":
+        return "float8"
+    elif quantization_dtype == "qint4":
+        return "int4"
+    elif quantization_dtype == "qint8":
+        return "int8"
+    elif quantization_dtype == "qint2":
+        return "int2"
+    else:
+        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
 
 
 class HFEmbedder(nn.Module):
     def __init__(
-        self, …
+        self,
+        version: str,
+        max_length: int,
+        device: torch.device | int,
+        quantization_dtype: str | None = None,
+        **hf_kwargs,
     ):
         super().__init__()
         self.is_clip = version.startswith("openai")
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+        quant_name = (
+            into_quantization_name(quantization_dtype) if quantization_dtype else None
+        )
 
         if self.is_clip:
             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                 version, max_length=max_length
             )
+
             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
-                version, …
+                version,
+                **hf_kwargs,
+                quantization_config=(
+                    QuantoConfig(
+                        weights=quant_name,
+                    )
+                    if quant_name
+                    else None
+                ),
+                device_map={"": device},
             )
-…
+
         else:
             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                 version, max_length=max_length
             )
             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                 version,
-                **hf_kwargs,
                 device_map={"": device},
-…
-…
+                **hf_kwargs,
+                quantization_config=(
+                    QuantoConfig(
+                        weights=quant_name,
+                    )
+                    if quant_name
+                    else None
                 ),
             )
 
@@ -51,3 +99,14 @@ class HFEmbedder(nn.Module):
             output_hidden_states=False,
         )
         return outputs[self.output_key]
+
+
+if __name__ == "__main__":
+    model = HFEmbedder(
+        "city96/t5-v1_1-xxl-encoder-bf16",
+        max_length=512,
+        device=0,
+        quantization_dtype="qfloat8",
+    )
+    o = model(["hello"])
+    print(o)
```
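The embedder now delegates quantization to transformers itself: `into_quantization_name` converts this repo's `"qint4"`-style strings into the `"int4"`-style names `QuantoConfig` expects, and `from_pretrained` quantizes weights at load time. A standalone sketch of that pattern (needs the quanto backend installed; downloads the CLIP checkpoint on first run):

```python
# Standalone sketch of the load-time quantization pattern HFEmbedder uses.
# QuantoConfig takes "float8"/"int8"/"int4"/"int2", hence the
# into_quantization_name() mapping in the diff above.
from transformers import CLIPTextModel
from transformers.utils.quantization_config import QuantoConfig

model = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    quantization_config=QuantoConfig(weights="int8"),
)
print(model.text_model.encoder.layers[0].mlp.fc1)  # a quantized linear
```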
    	
modules/flux_model.py

```diff
@@ -13,7 +13,6 @@ import math
 from torch import Tensor, nn
 from torch._dynamo import config
 from torch._inductor import config as ind_config
-from xformers.ops import memory_efficient_attention_forward
 from pydantic import BaseModel
 from torch.nn import functional as F
 
```
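This drops the xformers dependency. The diff only shows the import removal, but torch's built-in fused kernel (`scaled_dot_product_attention`, with `functional` already imported as `F` in this file) fills the same role as `memory_efficient_attention_forward`:

```python
# torch's built-in SDPA covers the fused-attention role the removed
# xformers import used to provide.
import torch
from torch.nn import functional as F

q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```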
    	
quantize_swap_and_dispatch.py

```diff
@@ -5,11 +5,35 @@ import torch
 from click import secho
 from cublas_ops import CublasLinear
 
-from quanto …
-…
+from quanto import (
+    QModuleMixin,
+    quantize_module,
+    QLinear,
+    QConv2d,
+    QLayerNorm,
+)
+from quanto.tensor import Optimizer, qtype, qfloat8, qint4, qint8
 from torch import nn
 
 
+class QuantizationDtype:
+    qfloat8 = "qfloat8"
+    qint2 = "qint2"
+    qint4 = "qint4"
+    qint8 = "qint8"
+
+
+def into_qtype(qtype: QuantizationDtype) -> qtype:
+    if qtype == QuantizationDtype.qfloat8:
+        return qfloat8
+    elif qtype == QuantizationDtype.qint4:
+        return qint4
+    elif qtype == QuantizationDtype.qint8:
+        return qint8
+    else:
+        raise ValueError(f"Unknown qtype: {qtype}")
+
+
 def _set_module_by_name(parent_module, name, child_module):
     module_names = name.split(".")
     if len(module_names) == 1:
@@ -121,7 +145,12 @@ def _is_block_compilable(module: nn.Module) -> bool:
 
 def _simple_swap_linears(model: nn.Module, root_name: str = ""):
     for name, module in model.named_children():
-        if …
+        if (
+            _is_linear(module)
+            and hasattr(module, "weight")
+            and module.weight is not None
+            and module.weight.data is not None
+        ):
             weights = module.weight.data
             bias = None
             if module.bias is not None:
@@ -155,7 +184,7 @@ def _full_quant(
     if current_quants < max_quants:
         current_quants += _quantize(model, quantization_dtype)
         _freeze(model)
-        print(f"Quantized {current_quants} modules")
+        print(f"Quantized {current_quants} modules with {quantization_dtype}")
     return current_quants
 
 
@@ -174,11 +203,13 @@ def quantize_and_dispatch_to_device(
     flux_device: torch.device = torch.device("cuda"),
     flux_dtype: torch.dtype = torch.float16,
     num_layers_to_quantize: int = 20,
-    quantization_dtype: …
+    quantization_dtype: QuantizationDtype = QuantizationDtype.qfloat8,
     compile_blocks: bool = True,
     compile_extras: bool = True,
     quantize_extras: bool = False,
+    replace_linears: bool = True,
 ):
+    quant_type = into_qtype(quantization_dtype)
     num_quanted = 0
     flow_model = flow_model.requires_grad_(False).eval().type(flux_dtype)
     for block in flow_model.single_blocks:
@@ -188,7 +219,7 @@ def quantize_and_dispatch_to_device(
                 block,
                 num_layers_to_quantize,
                 num_quanted,
-                quantization_dtype=…
+                quantization_dtype=quant_type,
             )
 
     for block in flow_model.double_blocks:
@@ -198,7 +229,7 @@ def quantize_and_dispatch_to_device(
                 block,
                 num_layers_to_quantize,
                 num_quanted,
-                quantization_dtype=…
+                quantization_dtype=quant_type,
             )
 
     to_gpu_extras = [
@@ -221,10 +252,11 @@ def quantize_and_dispatch_to_device(
             block.compile()
             secho(f"Compiled block {i}", fg="green")
 
-    …
+    if replace_linears:
+        _simple_swap_linears(flow_model)
     for extra in to_gpu_extras:
         m_extra = getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
-        if …
+        if compile_extras:
             if extra in ["time_in", "vector_in", "guidance_in", "final_layer"]:
                 m_extra.compile()
                 secho(
@@ -232,10 +264,11 @@ def quantize_and_dispatch_to_device(
                 fg="green",
             )
         elif quantize_extras:
-            …
-            …
-…
-…
-…
-…
+            if not isinstance(m_extra, nn.Linear):
+                _full_quant(
+                    m_extra,
+                    current_quants=num_quanted,
+                    max_quants=num_layers_to_quantize,
+                    quantization_dtype=quantization_dtype,
+                )
     return flow_model
```
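The hardened guard in `_simple_swap_linears` now skips modules without real weight data before rebuilding them (the repo swaps matches out for `CublasLinear`). A self-contained sketch of the same walk-and-replace pattern, with a stand-in factory where the real code builds a `CublasLinear`:

```python
# Self-contained sketch of the walk-and-replace pattern behind
# _simple_swap_linears; to_half() stands in for the CublasLinear rebuild.
from torch import nn

def swap_linears(model: nn.Module, factory) -> None:
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and child.weight is not None:
            setattr(model, name, factory(child))  # reattach in place
        else:
            swap_linears(child, factory)  # recurse into containers

def to_half(lin: nn.Linear) -> nn.Linear:
    new = nn.Linear(lin.in_features, lin.out_features, lin.bias is not None)
    new.load_state_dict(lin.state_dict())
    return new.half()

m = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
swap_linears(m, to_half)
print(m)  # both Linears rebuilt in fp16
```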
    	
util.py

```diff
@@ -18,6 +18,13 @@ class ModelVersion(StrEnum):
     flux_schnell = "flux-schnell"
 
 
+class QuantizationDtype(StrEnum):
+    qfloat8 = "qfloat8"
+    qint2 = "qint2"
+    qint4 = "qint4"
+    qint8 = "qint8"
+
+
 class ModelSpec(BaseModel):
     version: ModelVersion
     params: FluxParams
@@ -39,6 +46,10 @@ class ModelSpec(BaseModel):
     quantize_extras: bool = False
     compile_extras: bool = False
     compile_blocks: bool = False
+    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
+    text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
+    ae_quantization_dtype: Optional[QuantizationDtype] = None
+    clip_quantization_dtype: Optional[QuantizationDtype] = None
 
     model_config: ConfigDict = {
         "arbitrary_types_allowed": True,
@@ -199,13 +210,15 @@ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
         "openai/clip-vit-large-patch14",
         max_length=77,
         torch_dtype=into_dtype(config.text_enc_dtype),
-        device=into_device(config.text_enc_device),
+        device=into_device(config.text_enc_device).index or 0,
+        quantization_dtype=config.clip_quantization_dtype,
     )
     t5 = HFEmbedder(
         config.text_enc_path,
         max_length=config.text_enc_max_length,
         torch_dtype=into_dtype(config.text_enc_dtype),
        device=into_device(config.text_enc_device).index or 0,
+        quantization_dtype=config.text_enc_quantization_dtype,
     )
     return clip, t5
 
@@ -213,12 +226,22 @@ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
 def load_autoencoder(config: ModelSpec) -> AutoEncoder:
     ckpt_path = config.ae_path
     with torch.device("meta" if ckpt_path is not None else config.ae_device):
-        ae = AutoEncoder(config.ae_params)
+        ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype))
 
     if ckpt_path is not None:
         sd = load_sft(ckpt_path, device=str(config.ae_device))
         missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
         print_load_warning(missing, unexpected)
+    if config.ae_quantization_dtype is not None:
+        from quantize_swap_and_dispatch import _full_quant, into_qtype
+
+        ae.to(into_device(config.ae_device))
+        _full_quant(
+            ae,
+            max_quants=8000,
+            current_quants=0,
+            quantization_dtype=into_qtype(config.ae_quantization_dtype),
+        )
     return ae
 
```
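The new `ModelSpec` fields are what let the JSON configs above choose a dtype per component: pydantic coerces the `"qint4"`-style strings into the `QuantizationDtype` StrEnum, and absent keys fall back to the defaults. A reduced sketch of that validation (assumes pydantic v2, which the `ConfigDict` usage implies, and Python 3.11+ for `StrEnum`):

```python
# Reduced sketch of how the new optional fields validate; mirrors the
# ModelSpec defaults added above (pydantic v2, Python 3.11+).
from enum import StrEnum
from typing import Optional

from pydantic import BaseModel

class QuantizationDtype(StrEnum):
    qfloat8 = "qfloat8"
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"

class SpecSketch(BaseModel):
    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    ae_quantization_dtype: Optional[QuantizationDtype] = None

print(SpecSketch.model_validate({"flow_quantization_dtype": "qint4"}))
# flow_quantization_dtype=<QuantizationDtype.qint4: 'qint4'> ae_quantization_dtype=None
```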