Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
import datetime | |
import shutil | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import time | |
from pathlib import Path | |
import click | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from fish_speech.models.text2semantic.llama import find_multiple | |
from tools.llama.generate import load_model | |
##### Quantization Primitives ###### | |
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetrically quantize a 2-D tensor with one scale per row (axis 0).

    Each row's scale is chosen so that its largest magnitude maps to half of
    the integer span ``quant_max - quant_min``; zero points are always 0.

    Args:
        x: 2-D floating-point tensor whose rows are the quantization channels.
        quant_min: Smallest representable integer (e.g. -128 for int8).
        quant_max: Largest representable integer (e.g. 127 for int8).
        target_dtype: dtype of the returned quantized tensor.

    Returns:
        ``(quant, scales, zero_points)``: ``quant`` has ``x``'s shape and
        ``target_dtype``; ``scales`` holds one value per row in ``x.dtype``;
        ``zero_points`` is an all-zero int64 tensor with one entry per row.
    """
    # Assumptions kept from the reference implementation: symmetric scheme,
    # channel axis 0, dense memory layout.
    smallest_scale = torch.finfo(torch.float32).eps

    row_min, row_max = torch.aminmax(x, dim=1)
    # Force 0 into every row's range, then take the larger magnitude so the
    # quantization grid is symmetric around zero.
    negative_side = row_min.clamp(max=0)
    positive_side = row_max.clamp(min=0)
    magnitude = torch.max(-negative_side, positive_side)

    half_span = (quant_max - quant_min) / 2
    # Clamp protects all-zero rows from a zero scale; keep scales in x's dtype.
    scales = (magnitude / half_span).clamp(min=smallest_scale).to(x.dtype)
    zero_points = torch.zeros(
        row_min.size(), dtype=torch.int64, device=row_min.device
    )

    shifted = (x / scales.unsqueeze(-1)).round() + zero_points.unsqueeze(-1)
    quant = shifted.clamp(quant_min, quant_max).to(target_dtype)
    return quant, scales, zero_points
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute asymmetric per-group scales and zero offsets for a 2-D ``w``.

    Every run of ``groupsize`` consecutive values along the last dimension
    shares one ``(scale, zero)`` pair, derived from that group's min and max.

    Args:
        w: 2-D tensor without NaNs.
        n_bit: Bit width of the integer codes.
        groupsize: Values per group; clipped to ``w.shape[-1]``.

    Returns:
        ``(scales, zeros)``, both bfloat16 with shape
        ``(w.shape[0], w.shape[-1] // groupsize)``.
    """
    # GPTQ with padding can ask for a group wider than the tensor itself.
    groupsize = min(groupsize, w.shape[-1])
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    groups = w.reshape(-1, groupsize)
    assert torch.isnan(groups).sum() == 0

    group_max = groups.amax(dim=1, keepdim=True)
    group_min = groups.amin(dim=1, keepdim=True)
    top_code = 2**n_bit - 1
    scales = (group_max - group_min).clamp(min=1e-6) / top_code
    # "zero" is the real value represented by the mid-point code
    # 2**(n_bit - 1), not an integer zero point.
    zeros = group_min + scales * (2 ** (n_bit - 1))

    n_rows = w.shape[0]
    return (
        scales.to(torch.bfloat16).reshape(n_rows, -1),
        zeros.to(torch.bfloat16).reshape(n_rows, -1),
    )
def pack_scales_and_zeros(scales, zeros):
    """Interleave per-group ``scales`` and ``zeros`` into one bfloat16 tensor.

    Both inputs are ``(rows, n_groups)`` bfloat16 tensors of equal shape. The
    result is a contiguous ``(n_groups, rows, 2)`` tensor whose ``[..., 0]``
    slice holds the scales and ``[..., 1]`` the zeros. This matches the shape
    of WeightOnlyInt4Linear's ``scales_and_zeros`` buffer.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    n_rows, n_groups = scales.size(0), scales.size(1)
    paired = torch.stack(
        [scales.reshape(n_rows, n_groups), zeros.reshape(n_rows, n_groups)],
        dim=-1,
    )
    return paired.transpose(0, 1).contiguous()
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed ``(n_groups, rows, 2)`` tensor into scales and zeros.

    Accepts float32 or bfloat16 input. ``pack_scales_and_zeros`` in this
    module emits bfloat16, so accepting it allows a quantize/dequantize
    round trip.

    Returns:
        A 2-tuple ``(scales, zeros)``, each shaped ``(rows, n_groups, 1)``.
    """
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype in (torch.float, torch.bfloat16)
    # Undo the (rows, groups) -> (groups, rows) transpose applied when packing.
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize a 2-D ``w`` to unsigned ``n_bit`` codes with given group params.

    ``zeros`` holds the real value of the mid-point code ``2**(n_bit - 1)``,
    so the value of code 0 is ``zeros - scales * 2**(n_bit - 1)``.

    Returns:
        int32 tensor shaped like ``w`` with codes in ``[0, 2**n_bit - 1]``.
    """
    assert groupsize > 1
    # GPTQ quantizes single columns: with one param column, shrink the group
    # to the tensor width.
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    step = scales.reshape(-1, 1)
    mid_value = zeros.reshape(-1, 1)
    code_zero_value = mid_value - step * (2 ** (n_bit - 1))
    top_code = 2**n_bit - 1

    codes = ((grouped - code_zero_value) / step).round().clamp_(0, top_code)
    return codes.to(torch.int32).reshape_as(w)
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w``.

    Returns:
        ``(int32 codes shaped like w, packed bfloat16 scales/zeros)``.
    """
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    return (
        group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize),
        pack_scales_and_zeros(scales, zeros),
    )
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Map 2-D ``n_bit`` group codes back to real values.

    Each value is ``(code - 2**(n_bit - 1)) * scale + zero``, using the
    ``(scale, zero)`` pair of the group the code belongs to.
    """
    assert groupsize > 1
    # GPTQ dequantizes single columns: with one param column, shrink the group
    # to the tensor width.
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    midpoint = 2 ** (n_bit - 1)
    centered = w_int32.reshape(-1, groupsize).sub(midpoint)
    restored = centered.mul(scales.reshape(-1, 1)).add(zeros.reshape(-1, 1))
    return restored.reshape_as(w_int32)
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize ``w_int32`` using a packed ``(n_groups, rows, 2)`` scales/zeros tensor."""
    unpacked_scales, unpacked_zeros = unpack_scales_and_zeros(scales_and_zeros)
    return group_dequantize_tensor_from_qparams(
        w_int32, unpacked_scales, unpacked_zeros, n_bit=n_bit, groupsize=groupsize
    )
class QuantHandler:
    """Base shape of a quantizer: build a quantized state dict, then convert
    the model for runtime. Both hooks are stubs here and return ``None``.
    """

    def __init__(self, mod):
        # Model to be quantized.
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        """Return the quantized state dict (stub: returns None)."""

    def convert_for_runtime(self) -> "nn.Module":
        """Return the runtime-converted model (stub: returns None)."""
##### Weight-only int8 per-channel quantized code ###### | |
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every ``nn.Linear`` under ``module`` for a
    ``WeightOnlyInt8Linear`` of the same in/out size, in place.

    The replacements hold uninitialized int8 weights and no bias; real values
    must be loaded from a quantized state dict afterwards.
    """
    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_weight_only_int8_per_channel(child)
            continue
        replacement = WeightOnlyInt8Linear(child.in_features, child.out_features)
        setattr(module, child_name, replacement)
class WeightOnlyInt8QuantHandler:
    """Weight-only int8 quantizer with one symmetric scale per output channel."""

    def __init__(self, mod):
        # Model to quantize; convert_for_runtime() rewrites it in place.
        self.mod = mod

    def create_quantized_state_dict(self):
        """Return the model's state dict in which every Linear's weight is
        replaced by int8 codes and a ``<name>.scales`` entry (in the weight's
        dtype) is added."""
        state = self.mod.state_dict()
        linear_layers = (
            (name, layer)
            for name, layer in self.mod.named_modules()
            if isinstance(layer, torch.nn.Linear)
        )
        for name, layer in linear_layers:
            codes, channel_scales, _ = dynamically_quantize_per_channel(
                layer.weight.float(), -128, 127, torch.int8
            )
            state[f"{name}.weight"] = codes
            state[f"{name}.scales"] = channel_scales.to(layer.weight.dtype)
        return state

    def convert_for_runtime(self):
        """Swap Linear layers for WeightOnlyInt8Linear in place; return the model."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
class WeightOnlyInt8Linear(torch.nn.Module):
    """Bias-free linear layer with int8 weights and per-output-channel scales.

    ``forward`` computes ``linear(x, weight.to(x.dtype)) * scales``. Both
    ``weight`` (int8, ``(out_features, in_features)``) and ``scales``
    (bfloat16, ``(out_features,)``) are buffers filled from a quantized state
    dict.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        # ``bias`` and ``dtype`` are accepted but ignored: no bias is stored
        # and the buffer dtypes are fixed. ``device`` places both buffers.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=torch.bfloat16, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: cast int8 codes to the activation dtype,
        # matmul, then apply the per-output-channel scale.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
##### weight only int4 per channel groupwise quantized code ###### | |
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Group-quantize a weight to 4 bits and pack the codes with
    ``aten._convert_weight_to_int4pack``.

    Returns:
        ``(weight_int4pack, scales_and_zeros)``.
    """
    codes, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed = torch.ops.aten._convert_weight_to_int4pack(codes, inner_k_tiles)
    return packed, scales_and_zeros
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply a packed int4 linear layer to ``x`` with any leading dimensions.

    ``x`` is flattened to 2-D for ``aten._weight_int4pack_mm``. The result is
    reshaped back so that only the last dimension changes, to ``out_features``.
    """
    in_dim = x.shape[-1]
    batch_shape = x.shape[:-1]
    out = torch.ops.aten._weight_int4pack_mm(
        x.reshape(-1, in_dim), weight_int4pack, groupsize, scales_and_zeros
    )
    return out.reshape(batch_shape + (out_features,))
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): | |
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 | |
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively swap ``nn.Linear`` layers under ``module`` for
    ``WeightOnlyInt4Linear``, in place.

    A layer whose ``in_features`` already suits the int4 kernel is replaced
    without padding. Otherwise it gets a padding replacement when ``padding``
    is true and is left untouched when it is false.
    """
    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
            continue
        aligned = _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
        if not (aligned or padding):
            continue
        setattr(
            module,
            child_name,
            WeightOnlyInt4Linear(
                child.in_features,
                child.out_features,
                bias=False,
                groupsize=groupsize,
                inner_k_tiles=inner_k_tiles,
                padding=not aligned,
            ),
        )
class WeightOnlyInt4QuantHandler:
    """Weight-only int4 quantizer with asymmetric per-group scales and zeros.

    Args:
        mod: Model whose ``nn.Linear`` layers are quantized.
        groupsize: Input features sharing one scale/zero pair (32, 64, 128 or 256).
        inner_k_tiles: Tiling parameter for the int4 packing op (2, 4 or 8).
        padding: If true, layers whose ``in_features`` do not suit the kernel
            are zero-padded to a multiple of 1024; otherwise they are skipped.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    def create_quantized_state_dict(self):
        """Return the state dict with every eligible Linear weight replaced by
        a packed int4 tensor and a ``<name>.scales_and_zeros`` entry added.

        Packing runs on CUDA; the results are moved back to CPU.

        Raises:
            AssertionError: if a Linear has a bias or ``out_features % 8 != 0``.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue
            # Compare with None: truth-testing a multi-element bias tensor
            # raises RuntimeError instead of failing this check.
            assert mod.bias is None, "require bias=False"
            out_features = mod.out_features
            in_features = mod.in_features
            assert out_features % 8 == 0, "require out_features % 8 == 0"
            print(f"linear: {fqn}, in={in_features}, out={out_features}")

            weight = mod.weight.data
            if not _check_linear_int4_k(
                in_features, self.groupsize, self.inner_k_tiles
            ):
                if not self.padding:
                    print(
                        f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                        + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                    )
                    continue
                print(
                    f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                )
                padded_in_features = find_multiple(in_features, 1024)
                # Zero-pad the input dimension so its width is a multiple of 1024.
                weight = F.pad(weight, pad=(0, padded_in_features - in_features))

            (
                weight_int4pack,
                scales_and_zeros,
            ) = prepare_int4_weight_and_scales_and_zeros(
                weight.to(torch.bfloat16).to("cuda"),
                self.groupsize,
                self.inner_k_tiles,
            )
            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
        return cur_state_dict

    def convert_for_runtime(self):
        """Swap Linear layers for WeightOnlyInt4Linear in place; return the model."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
class WeightOnlyInt4Linear(torch.nn.Module):
    """Bias-free linear layer backed by packed int4 weights.

    With ``padding=True`` the stored ``in_features`` is rounded up to a
    multiple of 1024, and ``forward`` zero-pads inputs to that width. Inputs
    are cast to bfloat16 before the int4 matmul.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        # ``device`` and ``dtype`` are accepted but ignored; ``bias`` must be
        # passed as False explicitly (asserted below).
        super().__init__()
        self.padding = padding
        # Width of the activations callers pass in, recorded unconditionally
        # so the attribute exists whether or not padding is enabled.
        self.origin_in_features = in_features
        if padding:
            in_features = find_multiple(in_features, 1024)
        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"
        self.register_buffer(
            "weight",
            torch.empty(
                (
                    out_features // 8,
                    in_features // (inner_k_tiles * 16),
                    32,
                    inner_k_tiles // 2,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(torch.bfloat16)
        pad_width = self.in_features - self.origin_in_features
        if self.padding and pad_width:
            # Zero-pad activations to the padded width the weights were built for.
            input = F.pad(input, pad=(0, pad_width))
        return linear_forward_int4(
            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
def generate_folder_name():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
    """Quantize a checkpoint and save it into a copy of its folder.

    The whole ``checkpoint_path`` directory is copied to
    ``checkpoints/fs-1.2-int8-<stamp>`` (int8) or
    ``checkpoints/fs-1.2-int4-g<groupsize>-<stamp>`` (int4). The VQ generator
    weights file is removed from the copy, and the copy's ``model.pth`` is
    replaced by the quantized state dict.

    Args:
        checkpoint_path: Checkpoint directory (``str`` or ``Path``).
        mode: ``"int8"`` (weight-only, symmetric per-channel) or ``"int4"``
            (weight-only, affine groupwise; packing runs on CUDA).
        groupsize: Group size for int4 quantization (unused for int8).
        timestamp: Output-folder suffix; the literal string ``"None"`` means
            "use the current time".

    Raises:
        ValueError: if ``mode`` is not ``"int8"`` or ``"int4"``. The check
            runs before the model is loaded.
    """
    # Validate first so an invalid mode does not cost a full model load.
    if mode not in ("int8", "int4"):
        raise ValueError(
            f"Invalid quantization mode {mode} needs to be one of [int8, int4]"
        )

    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    model, _ = load_model(
        checkpoint_path=checkpoint_path,
        device=device,
        precision=precision,
        compile=False,
    )

    vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
    now = timestamp if timestamp != "None" else generate_folder_name()

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
    else:
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")

    quantized_state_dict = quant_handler.create_quantized_state_dict()

    # Path() also accepts a plain string; the original code required a Path here.
    shutil.copytree(str(Path(checkpoint_path).resolve()), str(dst_name.resolve()))
    # Drop the VQ generator weights from the copied folder.
    (dst_name / vq_model).unlink(missing_ok=True)
    quantize_path = dst_name / "model.pth"

    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")
if __name__ == "__main__":
    # ``quantize`` takes four required arguments; collect them from the
    # command line instead of calling it bare (which raises TypeError).

    @click.command()
    @click.option(
        "--checkpoint-path",
        type=click.Path(path_type=Path, exists=True, file_okay=False),
        required=True,
        help="Checkpoint directory to quantize.",
    )
    @click.option(
        "--mode",
        type=click.Choice(["int8", "int4"]),
        default="int8",
        show_default=True,
        help="Type of quantization to perform.",
    )
    @click.option(
        "--groupsize",
        type=int,
        default=128,
        show_default=True,
        help="Group size for int4 quantization.",
    )
    @click.option(
        "--timestamp",
        type=str,
        default="None",
        help='Output folder suffix; "None" uses the current time.',
    )
    def main(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
        quantize(checkpoint_path, mode, groupsize, timestamp)

    main()