from loguru import logger

import torch
import torch.nn as nn
from torch.nn import init
import math
from torch.compiler import is_compiling
from torch import __version__
from torch.version import cuda

from modules.flux_model import Modulation

# Torch below 2.4.9 is assumed to return a (output, amax) tuple from _scaled_mm;
# forward() indexes the result accordingly.
IS_TORCH_2_4 = __version__ < (2, 4, 9)
LT_TORCH_2_4 = __version__ < (2, 4)
if LT_TORCH_2_4:
    if not hasattr(torch, "_scaled_mm"):
        raise RuntimeError(
            "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
        )
CUDA_VERSION = float(cuda) if cuda else 0
if CUDA_VERSION < 12.4:
    raise RuntimeError(
        f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later; got torch version {__version__} and CUDA version {cuda}."
    )

try:
    from cublas_ops import CublasLinear
except ImportError:
    # cublas_ops is optional; fall back to a sentinel type so isinstance checks still work.
    CublasLinear = type(None)
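
# Why two float8 formats (informational note): e4m3 keeps more mantissa bits with a
# smaller dynamic range, which suits the static, well-bounded weight tensor, while
# e5m2 trades mantissa for range, which is more forgiving for activations whose
# magnitude is only estimated from a handful of calibration batches. The ranges can
# be inspected directly:
#
#   torch.finfo(torch.float8_e4m3fn).max  # 448.0
#   torch.finfo(torch.float8_e5m2).max    # 57344.0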
|
class F8Linear(nn.Module):

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=torch.float16,
        float8_dtype=torch.float8_e4m3fn,
        float_weight: torch.Tensor = None,
        float_bias: torch.Tensor = None,
        num_scale_trials: int = 12,
        input_float8_dtype=torch.float8_e5m2,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.float8_dtype = float8_dtype
        self.input_float8_dtype = input_float8_dtype
        self.input_scale_initialized = False
        self.weight_initialized = False
        self.max_value = torch.finfo(self.float8_dtype).max
        self.input_max_value = torch.finfo(self.input_float8_dtype).max
        factory_kwargs = {"dtype": dtype, "device": device}
        if float_weight is None:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features), **factory_kwargs)
            )
        else:
            self.weight = nn.Parameter(
                float_weight, requires_grad=float_weight.requires_grad
            )
        if float_bias is None:
            if bias:
                self.bias = nn.Parameter(
                    torch.empty(out_features, **factory_kwargs),
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad)
        self.num_scale_trials = num_scale_trials
        # Per-batch absolute maxima collected while calibrating the input scale.
        self.input_amax_trials = torch.zeros(
            num_scale_trials, requires_grad=False, device=device, dtype=torch.float32
        )
        self.trial_index = 0
        # Buffers start empty and are filled by quantize_weight() / quantize_input().
        self.register_buffer("scale", None)
        self.register_buffer("input_scale", None)
        self.register_buffer("float8_data", None)
        self.register_buffer("scale_reciprocal", None)
        self.register_buffer("input_scale_reciprocal", None)
|
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        sd = {k.replace(prefix, ""): v for k, v in state_dict.items()}
        if "weight" in sd:
            if (
                "float8_data" not in sd or sd["float8_data"] is None
            ) and sd["weight"].shape == (self.out_features, self.in_features):
                # Unquantized checkpoint: load the full-precision weight and quantize it now.
                self._parameters["weight"] = nn.Parameter(
                    sd["weight"], requires_grad=False
                )
                if "bias" in sd:
                    self._parameters["bias"] = nn.Parameter(
                        sd["bias"], requires_grad=False
                    )
                self.quantize_weight()
            elif (
                "float8_data" in sd
                and sd["float8_data"] is not None
                and sd["float8_data"].shape == (self.out_features, self.in_features)
                and torch.equal(sd["weight"], torch.zeros_like(sd["weight"]))
            ):
                # Pre-quantized checkpoint: the stored weight is only a zero placeholder.
                w = sd["weight"]
                self._buffers["float8_data"] = sd["float8_data"]
                self._parameters["weight"] = nn.Parameter(
                    torch.zeros(
                        1,
                        dtype=w.dtype,
                        device=w.device,
                        requires_grad=False,
                    )
                )
                if "bias" in sd:
                    self._parameters["bias"] = nn.Parameter(
                        sd["bias"], requires_grad=False
                    )
                self.weight_initialized = True

                if all(
                    key in sd
                    for key in [
                        "scale",
                        "input_scale",
                        "scale_reciprocal",
                        "input_scale_reciprocal",
                    ]
                ):
                    self.scale = sd["scale"].float()
                    self.input_scale = sd["input_scale"].float()
                    self.scale_reciprocal = sd["scale_reciprocal"].float()
                    self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
                    self.input_scale_initialized = True
                    self.trial_index = self.num_scale_trials
                elif "scale" in sd and "scale_reciprocal" in sd:
                    # The weight scale is present but the input-scale pair is incomplete:
                    # load whatever exists and force input-scale recalibration.
                    self.scale = sd["scale"].float()
                    self.input_scale = (
                        sd["input_scale"].float() if "input_scale" in sd else None
                    )
                    self.scale_reciprocal = sd["scale_reciprocal"].float()
                    self.input_scale_reciprocal = (
                        sd["input_scale_reciprocal"].float()
                        if "input_scale_reciprocal" in sd
                        else None
                    )
                    self.input_amax_trials = torch.zeros(
                        self.num_scale_trials,
                        requires_grad=False,
                        dtype=torch.float32,
                        device=self.weight.device,
                    )
                    self.input_scale_initialized = False
                    self.trial_index = 0
                else:
                    # No scales in the checkpoint: recalibrate the input scale from scratch.
                    self.input_scale_initialized = False
                    self.trial_index = 0
                    self.input_amax_trials = torch.zeros(
                        self.num_scale_trials,
                        requires_grad=False,
                        dtype=torch.float32,
                        device=self.weight.device,
                    )
            else:
                raise RuntimeError(
                    f"Weight or float8_data tensor has an unexpected shape in state dict: {list(sd.keys())}"
                )
        else:
            raise RuntimeError("Weight tensor not found in state dict")
|
    def quantize_weight(self):
        if self.weight_initialized:
            return
        # Per-tensor scale chosen so the largest weight magnitude maps to the fp8 max.
        amax = torch.max(torch.abs(self.weight.data)).float()
        self.scale = self.amax_to_scale(amax, self.max_value)
        self.float8_data = self.to_fp8_saturated(
            self.weight.data, self.scale, self.max_value
        ).to(self.float8_dtype)
        self.scale_reciprocal = self.scale.reciprocal()
        # Free the full-precision weight; keep a 1-element placeholder so the parameter survives.
        self.weight.data = torch.zeros(
            1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
        )
        self.weight_initialized = True
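
    # Numeric sketch of the quantization above (made-up values, not executed):
    # with max_value = 448.0 for float8_e4m3fn,
    #
    #   w = torch.tensor([[0.5, -2.0], [1.0, 0.25]])
    #   amax = w.abs().max()                    # 2.0
    #   scale = 448.0 / amax                    # 224.0
    #   w_fp8 = (w * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    #   w_approx = w_fp8.float() / scale        # dequantized inside the fp8 matmul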
|
    def set_weight_tensor(self, tensor: torch.Tensor):
        self.weight.data = tensor
        self.weight_initialized = False
        self.quantize_weight()

    def amax_to_scale(self, amax, max_val):
        return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)

    def to_fp8_saturated(self, x, scale, max_val):
        return (x * scale).clamp(-max_val, max_val)

    def quantize_input(self, x: torch.Tensor):
        if self.input_scale_initialized:
            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
        elif self.trial_index < self.num_scale_trials:
            # Calibration phase: record this batch's amax and refresh the input scale
            # from the running maximum of all trials seen so far.
            amax = torch.max(torch.abs(x)).float()
            self.input_amax_trials[self.trial_index] = amax
            self.trial_index += 1
            self.input_scale = self.amax_to_scale(
                self.input_amax_trials[: self.trial_index].max(), self.input_max_value
            )
            self.input_scale_reciprocal = self.input_scale.reciprocal()
            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
        else:
            # Calibration finished: freeze the input scale at the largest amax observed.
            self.input_scale = self.amax_to_scale(
                self.input_amax_trials.max(), self.input_max_value
            )
            self.input_scale_reciprocal = self.input_scale.reciprocal()
            self.input_scale_initialized = True
            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
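
    # Calibration sketch (the amax values are hypothetical): with the default
    # num_scale_trials = 12, the first 12 calls that reach quantize_input() refresh
    # input_scale from the largest activation amax seen so far; after that the scale
    # is frozen and forward() takes the fast path via input_scale_initialized.
    #
    #   amax_trials = torch.tensor([3.1, 2.7, 4.0])   # one entry per calibration batch
    #   input_scale = 57344.0 / amax_trials.max()     # e5m2 max / observed amax
    #   x_fp8 = (x * input_scale).clamp(-57344.0, 57344.0).to(torch.float8_e5m2)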
|
    def reset_parameters(self) -> None:
        if self.weight_initialized:
            # Re-allocate a full-precision weight; it was replaced by a placeholder
            # after quantization.
            self.weight = nn.Parameter(
                torch.empty(
                    (self.out_features, self.in_features),
                    dtype=self.weight.dtype,
                    device=self.weight.device,
                )
            )
            self.weight_initialized = False
            self.input_scale_initialized = False
            self.trial_index = 0
            self.input_amax_trials.zero_()
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)
        self.quantize_weight()
        self.max_value = torch.finfo(self.float8_dtype).max
        self.input_max_value = torch.finfo(self.input_float8_dtype).max
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.input_scale_initialized or is_compiling():
            # Under torch.compile we assume input-scale calibration has already finished.
            x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
                self.input_float8_dtype
            )
        else:
            x = self.quantize_input(x)

        prev_dims = x.shape[:-1]
        x = x.view(-1, self.in_features)

        # _scaled_mm expects a column-major second operand, hence the transpose.
        out = torch._scaled_mm(
            x,
            self.float8_data.T,
            scale_a=self.input_scale_reciprocal,
            scale_b=self.scale_reciprocal,
            bias=self.bias,
            out_dtype=self.weight.dtype,
            use_fast_accum=True,
        )
        if IS_TORCH_2_4:
            # On torch < 2.4.9, _scaled_mm returns a (output, amax) tuple; keep the output.
            out = out[0]
        out = out.view(*prev_dims, self.out_features)
        return out
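
    # What the fused _scaled_mm call computes, written as plain tensor ops (a
    # conceptual sketch only; the real kernel never materializes the dequantized
    # operands):
    #
    #   x_deq = x.float() * self.input_scale_reciprocal
    #   w_deq = self.float8_data.float() * self.scale_reciprocal
    #   out   = x_deq @ w_deq.T + (self.bias if self.bias is not None else 0)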
|
    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        float8_dtype=torch.float8_e4m3fn,
        input_float8_dtype=torch.float8_e5m2,
    ) -> "F8Linear":
        f8_lin = cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
            float8_dtype=float8_dtype,
            float_weight=linear.weight.data,
            float_bias=(linear.bias.data if linear.bias is not None else None),
            input_float8_dtype=input_float8_dtype,
        )
        f8_lin.quantize_weight()
        return f8_lin
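
# Minimal usage sketch (illustrative only; assumes a CUDA device with float8 support
# and the module's default dtypes):
#
#   lin = nn.Linear(3072, 3072, dtype=torch.float16, device="cuda")
#   f8 = F8Linear.from_linear(lin)
#   x = torch.randn(1, 256, 3072, dtype=torch.float16, device="cuda")
#   y = f8(x)  # the first few calls also calibrate the input scale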
|
@torch.inference_mode()
def recursive_swap_linears(
    model: nn.Module,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    quantize_modulation: bool = True,
    ignore_keys: list[str] = [],
) -> None:
    """
    Recursively swaps all nn.Linear modules in the given model with F8Linear modules.

    This function traverses the model's structure and replaces each nn.Linear
    instance with an F8Linear instance, which uses 8-bit floating point
    quantization for weights. The original linear layer's weights are deleted
    after conversion to save memory.

    Args:
        model (nn.Module): The PyTorch model to modify.
        float8_dtype: Float8 dtype used for the quantized weights.
        input_float8_dtype: Float8 dtype used for the quantized inputs.
        quantize_modulation (bool): If False, Modulation modules are left untouched.
        ignore_keys (list[str]): Child module names to skip.

    Note:
        This function modifies the model in-place. After calling it, all eligible
        linear layers (everything except skipped keys, existing F8Linear/CublasLinear
        instances, and Modulation modules when quantize_modulation is False) will
        use 8-bit quantization.
    """
    for name, child in model.named_children():
        if name in ignore_keys:
            continue
        if isinstance(child, Modulation) and not quantize_modulation:
            continue
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            setattr(
                model,
                name,
                F8Linear.from_linear(
                    child,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                ),
            )
            del child
        else:
            recursive_swap_linears(
                child,
                float8_dtype=float8_dtype,
                input_float8_dtype=input_float8_dtype,
                quantize_modulation=quantize_modulation,
                ignore_keys=ignore_keys,
            )
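
# Usage sketch (illustrative; `block` stands for any module tree containing
# nn.Linear layers, e.g. a transformer block already moved to the GPU, and the
# ignore key name is hypothetical):
#
#   recursive_swap_linears(block, quantize_modulation=False, ignore_keys=["norm"])
#   torch.cuda.empty_cache()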
|

@torch.inference_mode()
def swap_to_cublaslinear(model: nn.Module):
    if CublasLinear == type(None):
        # cublas_ops is not installed; nothing to swap.
        return
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            cublas_lin = CublasLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            cublas_lin.weight.data = child.weight.clone().detach()
            if child.bias is not None:
                cublas_lin.bias.data = child.bias.clone().detach()
            setattr(model, name, cublas_lin)
            del child
        else:
            swap_to_cublaslinear(child)
|
@torch.inference_mode()
def quantize_flow_transformer_and_dispatch_float8(
    flow_model: nn.Module,
    device=torch.device("cuda"),
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    offload_flow=False,
    swap_linears_with_cublaslinear=True,
    flow_dtype=torch.float16,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = True,
) -> nn.Module:
    """
    Quantize the flux flow transformer model (original BFL codebase version) and dispatch it to the given device.

    Iteratively pushes each module to the device, puts it in eval mode, replaces its
    linear layers with F8Linear (except for final_layer), and quantizes them. Moving
    one block at a time keeps peak memory low, so dispatch and quantization work even
    on GPUs with limited VRAM.

    After dispatching, if offload_flow is True, the model is offloaded back to the CPU.

    If swap_linears_with_cublaslinear is True and flow_dtype == torch.float16, the
    remaining unquantized linears are swapped with CublasLinear for roughly a 2x speedup
    on consumer GPUs; otherwise the cublas swap is skipped.

    For extra precision, set quantize_flow_embedder_layers to False. This preserves the
    output quality of the flow transformer better than fully quantizing, at the cost of
    ~512MB more VRAM.

    For extra precision, set quantize_modulation to False. This also preserves output
    quality better than fully quantizing, at the cost of ~2GB more VRAM, and has a much
    larger impact on image quality than the embedder layers.
    """
    for module in flow_model.double_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module,
            float8_dtype=float8_dtype,
            input_float8_dtype=input_float8_dtype,
            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    for module in flow_model.single_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module,
            float8_dtype=float8_dtype,
            input_float8_dtype=input_float8_dtype,
            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    to_gpu_extras = [
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    ]
    for module in to_gpu_extras:
        m_extra = getattr(flow_model, module)
        if m_extra is None:
            continue
        m_extra.to(device)
        m_extra.eval()
        if isinstance(m_extra, nn.Linear) and not isinstance(
            m_extra, (F8Linear, CublasLinear)
        ):
            if quantize_flow_embedder_layers:
                setattr(
                    flow_model,
                    module,
                    F8Linear.from_linear(
                        m_extra,
                        float8_dtype=float8_dtype,
                        input_float8_dtype=input_float8_dtype,
                    ),
                )
                del m_extra
        elif module != "final_layer":
            if quantize_flow_embedder_layers:
                recursive_swap_linears(
                    m_extra,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                    quantize_modulation=quantize_modulation,
                )
        torch.cuda.empty_cache()
    if (
        swap_linears_with_cublaslinear
        and flow_dtype == torch.float16
        and CublasLinear != type(None)
    ):
        swap_to_cublaslinear(flow_model)
    elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
        logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
    if offload_flow:
        flow_model.to("cpu")
        torch.cuda.empty_cache()
    return flow_model
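
# End-to-end usage sketch (illustrative; `load_flow_model` is a stand-in for however
# the flux transformer is constructed elsewhere in the codebase):
#
#   flow_model = load_flow_model(...)  # hypothetical loader; fp16/bf16 weights
#   flow_model = quantize_flow_transformer_and_dispatch_float8(
#       flow_model,
#       device=torch.device("cuda"),
#       offload_flow=False,
#       quantize_modulation=True,
#       quantize_flow_embedder_layers=False,  # ~512MB more VRAM, better fidelity
#   )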