File size: 18,623 Bytes
37bd8c1 28dec30 b6617b1 37bd8c1 28dec30 340f0a0 b6617b1 28dec30 2f2c44c 28dec30 2f2c44c 28dec30 37bd8c1 28dec30 2f2c44c 3ddaa67 2f2c44c 3ddaa67 2f2c44c 28dec30 d45a331 28dec30 d45a331 28dec30 37bd8c1 d45a331 28dec30 d45a331 28dec30 d45a331 28dec30 d45a331 28dec30 d45a331 28dec30 d45a331 28dec30 d45a331 28dec30 0f3134f 28dec30 d45a331 28dec30 b6617b1 28dec30 b6617b1 28dec30 b6617b1 28dec30 0f3134f 28dec30 37bd8c1 28dec30 37bd8c1 3ddaa67 0f3134f 28dec30 3ddaa67 37bd8c1 28dec30 37bd8c1 3ddaa67 37bd8c1 56c313c 37bd8c1 28dec30 6d0762c 28dec30 37bd8c1 6d0762c 28dec30 0f3134f 37bd8c1 28dec30 0f3134f 28dec30 37bd8c1 28dec30 0f3134f 28dec30 37bd8c1 28dec30 1f9e684 28dec30 1f9e684 28dec30 1f9e684 28dec30 00f5d2c 56c313c 00f5d2c 37bd8c1 28dec30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 |
from loguru import logger
import torch
import torch.nn as nn
from torch.nn import init
import math
from torch.compiler import is_compiling
from torch import __version__
from torch.version import cuda
from modules.flux_model import Modulation
IS_TORCH_2_4 = __version__ < (2, 4, 9)
LT_TORCH_2_4 = __version__ < (2, 4)
if LT_TORCH_2_4:
if not hasattr(torch, "_scaled_mm"):
raise RuntimeError(
"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
)
CUDA_VERSION = float(cuda) if cuda else 0
if CUDA_VERSION < 12.4:
raise RuntimeError(
f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later got torch version {__version__} and CUDA version {cuda}."
)
try:
from cublas_ops import CublasLinear
except ImportError:
CublasLinear = type(None)
class F8Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=torch.float16,
float8_dtype=torch.float8_e4m3fn,
float_weight: torch.Tensor = None,
float_bias: torch.Tensor = None,
num_scale_trials: int = 12,
input_float8_dtype=torch.float8_e5m2,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.float8_dtype = float8_dtype
self.input_float8_dtype = input_float8_dtype
self.input_scale_initialized = False
self.weight_initialized = False
self.max_value = torch.finfo(self.float8_dtype).max
self.input_max_value = torch.finfo(self.input_float8_dtype).max
factory_kwargs = {"dtype": dtype, "device": device}
if float_weight is None:
self.weight = nn.Parameter(
torch.empty((out_features, in_features), **factory_kwargs)
)
else:
self.weight = nn.Parameter(
float_weight, requires_grad=float_weight.requires_grad
)
if float_bias is None:
if bias:
self.bias = nn.Parameter(
torch.empty(out_features, **factory_kwargs),
)
else:
self.register_parameter("bias", None)
else:
self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad)
self.num_scale_trials = num_scale_trials
self.input_amax_trials = torch.zeros(
num_scale_trials, requires_grad=False, device=device, dtype=torch.float32
)
self.trial_index = 0
self.register_buffer("scale", None)
self.register_buffer(
"input_scale",
None,
)
self.register_buffer(
"float8_data",
None,
)
self.scale_reciprocal = self.register_buffer("scale_reciprocal", None)
self.input_scale_reciprocal = self.register_buffer(
"input_scale_reciprocal", None
)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
sd = {k.replace(prefix, ""): v for k, v in state_dict.items()}
if "weight" in sd:
if (
"float8_data" not in sd
or sd["float8_data"] is None
and sd["weight"].shape == (self.out_features, self.in_features)
):
# Initialize as if it's an F8Linear that needs to be quantized
self._parameters["weight"] = nn.Parameter(
sd["weight"], requires_grad=False
)
if "bias" in sd:
self._parameters["bias"] = nn.Parameter(
sd["bias"], requires_grad=False
)
self.quantize_weight()
elif sd["float8_data"].shape == (
self.out_features,
self.in_features,
) and sd["weight"] == torch.zeros_like(sd["weight"]):
w = sd["weight"]
# Set the init values as if it's already quantized float8_data
self._buffers["float8_data"] = sd["float8_data"]
self._parameters["weight"] = nn.Parameter(
torch.zeros(
1,
dtype=w.dtype,
device=w.device,
requires_grad=False,
)
)
if "bias" in sd:
self._parameters["bias"] = nn.Parameter(
sd["bias"], requires_grad=False
)
self.weight_initialized = True
# Check if scales and reciprocals are initialized
if all(
key in sd
for key in [
"scale",
"input_scale",
"scale_reciprocal",
"input_scale_reciprocal",
]
):
self.scale = sd["scale"].float()
self.input_scale = sd["input_scale"].float()
self.scale_reciprocal = sd["scale_reciprocal"].float()
self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
self.input_scale_initialized = True
self.trial_index = self.num_scale_trials
elif "scale" in sd and "scale_reciprocal" in sd:
self.scale = sd["scale"].float()
self.input_scale = (
sd["input_scale"].float() if "input_scale" in sd else None
)
self.scale_reciprocal = sd["scale_reciprocal"].float()
self.input_scale_reciprocal = (
sd["input_scale_reciprocal"].float()
if "input_scale_reciprocal" in sd
else None
)
self.input_scale_initialized = (
True if "input_scale" in sd else False
)
self.trial_index = (
self.num_scale_trials if "input_scale" in sd else 0
)
self.input_amax_trials = torch.zeros(
self.num_scale_trials,
requires_grad=False,
dtype=torch.float32,
device=self.weight.device,
)
self.input_scale_initialized = False
self.trial_index = 0
else:
# If scales are not initialized, reset trials
self.input_scale_initialized = False
self.trial_index = 0
self.input_amax_trials = torch.zeros(
self.num_scale_trials, requires_grad=False, dtype=torch.float32
)
else:
raise RuntimeError(
f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}"
)
else:
raise RuntimeError(
"Weight tensor not found or has incorrect shape in state dict"
)
def quantize_weight(self):
if self.weight_initialized:
return
amax = torch.max(torch.abs(self.weight.data)).float()
self.scale = self.amax_to_scale(amax, self.max_value)
self.float8_data = self.to_fp8_saturated(
self.weight.data, self.scale, self.max_value
).to(self.float8_dtype)
self.scale_reciprocal = self.scale.reciprocal()
self.weight.data = torch.zeros(
1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
)
self.weight_initialized = True
def set_weight_tensor(self, tensor: torch.Tensor):
self.weight.data = tensor
self.weight_initialized = False
self.quantize_weight()
def amax_to_scale(self, amax, max_val):
return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)
def to_fp8_saturated(self, x, scale, max_val):
return (x * scale).clamp(-max_val, max_val)
def quantize_input(self, x: torch.Tensor):
if self.input_scale_initialized:
return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
self.input_float8_dtype
)
elif self.trial_index < self.num_scale_trials:
amax = torch.max(torch.abs(x)).float()
self.input_amax_trials[self.trial_index] = amax
self.trial_index += 1
self.input_scale = self.amax_to_scale(
self.input_amax_trials[: self.trial_index].max(), self.input_max_value
)
self.input_scale_reciprocal = self.input_scale.reciprocal()
return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
self.input_float8_dtype
)
else:
self.input_scale = self.amax_to_scale(
self.input_amax_trials.max(), self.input_max_value
)
self.input_scale_reciprocal = self.input_scale.reciprocal()
self.input_scale_initialized = True
return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
self.input_float8_dtype
)
def reset_parameters(self) -> None:
if self.weight_initialized:
self.weight = nn.Parameter(
torch.empty(
(self.out_features, self.in_features),
**{
"dtype": self.weight.dtype,
"device": self.weight.device,
},
)
)
self.weight_initialized = False
self.input_scale_initialized = False
self.trial_index = 0
self.input_amax_trials.zero_()
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
self.quantize_weight()
self.max_value = torch.finfo(self.float8_dtype).max
self.input_max_value = torch.finfo(self.input_float8_dtype).max
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.input_scale_initialized or is_compiling():
x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
self.input_float8_dtype
)
else:
x = self.quantize_input(x)
prev_dims = x.shape[:-1]
x = x.view(-1, self.in_features)
# float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
out = torch._scaled_mm(
x,
self.float8_data.T,
scale_a=self.input_scale_reciprocal,
scale_b=self.scale_reciprocal,
bias=self.bias,
out_dtype=self.weight.dtype,
use_fast_accum=True,
)
if IS_TORCH_2_4:
out = out[0]
out = out.view(*prev_dims, self.out_features)
return out
@classmethod
def from_linear(
cls,
linear: nn.Linear,
float8_dtype=torch.float8_e4m3fn,
input_float8_dtype=torch.float8_e5m2,
) -> "F8Linear":
f8_lin = cls(
in_features=linear.in_features,
out_features=linear.out_features,
bias=linear.bias is not None,
device=linear.weight.device,
dtype=linear.weight.dtype,
float8_dtype=float8_dtype,
float_weight=linear.weight.data,
float_bias=(linear.bias.data if linear.bias is not None else None),
input_float8_dtype=input_float8_dtype,
)
f8_lin.quantize_weight()
return f8_lin
@torch.inference_mode()
def recursive_swap_linears(
    model: nn.Module,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    quantize_modulation: bool = True,
    ignore_keys: "list[str] | tuple[str, ...]" = (),
) -> None:
    """
    Recursively replace every plain ``nn.Linear`` under ``model`` with an ``F8Linear``.

    Each matching child is converted with ``F8Linear.from_linear`` (weights
    quantized to float8) and set on its parent in place of the original
    ``nn.Linear``. Children that are already ``F8Linear`` or ``CublasLinear``
    are not converted. Every other child is descended into.

    Args:
        model (nn.Module): Module to modify in place.
        float8_dtype: float8 dtype for the quantized weights.
        input_float8_dtype: float8 dtype activations are quantized to.
        quantize_modulation (bool): When False, ``Modulation`` children are
            skipped entirely (neither converted nor descended into).
        ignore_keys: Child attribute names to skip; matched against the
            immediate child name at every level of the recursion.

    Note:
        This function modifies the model in-place.
    """
    for name, child in model.named_children():
        if name in ignore_keys:
            continue
        if isinstance(child, Modulation) and not quantize_modulation:
            continue
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            setattr(
                model,
                name,
                F8Linear.from_linear(
                    child,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                ),
            )
        else:
            recursive_swap_linears(
                child,
                float8_dtype=float8_dtype,
                input_float8_dtype=input_float8_dtype,
                quantize_modulation=quantize_modulation,
                ignore_keys=ignore_keys,
            )
@torch.inference_mode()
def swap_to_cublaslinear(model: nn.Module):
    """Recursively replace plain ``nn.Linear`` children with ``CublasLinear``.

    Weights (and bias, when present) are copied into the new layer.
    ``F8Linear`` and existing ``CublasLinear`` children are not converted.
    Does nothing when ``cublas_ops`` could not be imported.

    Args:
        model: Module to modify in place.
    """
    if CublasLinear is type(None):
        return
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            cublas_lin = CublasLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            cublas_lin.weight.data = child.weight.detach().clone()
            # Bias-less linears have bias=None; copying unconditionally crashed.
            if child.bias is not None:
                cublas_lin.bias.data = child.bias.detach().clone()
            setattr(model, name, cublas_lin)
        else:
            swap_to_cublaslinear(child)
@torch.inference_mode()
def quantize_flow_transformer_and_dispatch_float8(
    flow_model: nn.Module,
    device=torch.device("cuda"),
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    offload_flow=False,
    swap_linears_with_cublaslinear=True,
    flow_dtype=torch.float16,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = True,
) -> nn.Module:
    """
    Quantize a Flux flow transformer (original BFL layout) to float8 and move it to ``device``.

    The model is processed piece by piece. Each double block, then each single
    block, is moved to ``device``, put in eval mode and has its linears swapped
    for ``F8Linear``, with the CUDA cache emptied after each piece. This keeps
    peak device memory low on GPUs with limited VRAM.

    The extra top-level modules (``vector_in``, ``img_in``, ``txt_in``,
    ``time_in``, ``guidance_in``, ``final_layer``, ``pe_embedder``) are then
    moved as well. They are quantized only when ``quantize_flow_embedder_layers``
    is True, and ``final_layer`` is never recursed into.

    Options:
        quantize_flow_embedder_layers: Set to False to keep the embedder layers
            unquantized. This preserves more output quality at the cost of
            roughly 512MB of extra VRAM.
        quantize_modulation: Set to False to keep Modulation layers unquantized.
            This costs roughly 2GB of extra VRAM and has a much larger effect on
            image quality than the embedder layers.
        swap_linears_with_cublaslinear: When True and ``flow_dtype`` is
            float16, the remaining ``nn.Linear`` layers are swapped for
            ``CublasLinear`` (if available). The swap targets a speed-up on
            consumer GPUs. With any other ``flow_dtype`` a warning is logged
            and the swap is skipped.
        offload_flow: When True, the model is moved to CPU at the end.

    Returns:
        The same ``flow_model`` instance, modified in place.
    """
    for blocks_attr in ("double_blocks", "single_blocks"):
        for block in getattr(flow_model, blocks_attr):
            block.to(device)
            block.eval()
            recursive_swap_linears(
                block,
                float8_dtype=float8_dtype,
                input_float8_dtype=input_float8_dtype,
                quantize_modulation=quantize_modulation,
            )
            torch.cuda.empty_cache()

    extra_names = (
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    )
    for extra_name in extra_names:
        extra = getattr(flow_model, extra_name)
        if extra is None:
            continue
        extra.to(device)
        extra.eval()
        is_plain_linear = isinstance(extra, nn.Linear) and not isinstance(
            extra, (F8Linear, CublasLinear)
        )
        if is_plain_linear:
            if quantize_flow_embedder_layers:
                quantized = F8Linear.from_linear(
                    extra,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                )
                setattr(flow_model, extra_name, quantized)
                # Drop the local reference so the float weights can be freed
                # before the cache is emptied.
                del extra
        elif extra_name != "final_layer" and quantize_flow_embedder_layers:
            recursive_swap_linears(
                extra,
                float8_dtype=float8_dtype,
                input_float8_dtype=input_float8_dtype,
                quantize_modulation=quantize_modulation,
            )
        torch.cuda.empty_cache()

    if swap_linears_with_cublaslinear:
        if flow_dtype == torch.float16:
            if CublasLinear is not type(None):
                swap_to_cublaslinear(flow_model)
        else:
            logger.warning("Skipping cublas linear swap because flow_dtype is not float16")

    if offload_flow:
        flow_model.to("cpu")
        torch.cuda.empty_cache()
    return flow_model
|