|
import torch |
|
import numpy as np |
|
import os |
|
|
|
def calc_sparsity(tensor):
    """Return ``(rate, nnz)`` for a ``torch.Tensor`` or numpy array.

    ``rate`` is the fraction of zero-valued elements (Python float in
    ``[0, 1]``); ``nnz`` is the count of nonzero elements (Python int).

    Fix vs. original: the torch branch returned ``nnz`` as a 0-dim tensor
    while the numpy branch returned a plain int — both branches now return
    a Python int so callers get a consistent type.
    """
    if isinstance(tensor, torch.Tensor):
        # .item() converts the 0-dim count tensor to a plain int,
        # matching the numpy branch below.
        nnz = tensor.count_nonzero().item()
        rate = 1 - (nnz / tensor.numel())
        return rate, nnz
    else:
        nnz = np.count_nonzero(tensor)
        rate = 1 - (nnz / tensor.size)
        return rate, nnz
|
|
|
if __name__ == "__main__":
    # One-off conversion script: load a GPTQ-style 4-bit checkpoint of a
    # llama3-8B MLP up_proj layer, tile it, compress the sparse quantized
    # codes per tile, bit-pack everything into int32 words, dump raw .bin
    # files, then reload them and verify the round trip end to end.
    sd = torch.load("./sqft_llama3_8B_gptq_tx1_mlp.pth")

    # List every key in the checkpoint (manual inspection aid).
    for k,v in sd.items():
        print(k)

    # Dequantized weight plus its per-group quantization parameters.
    weight = sd['up_proj.weight']
    scales = sd['up_proj.scales']
    zeros = sd['up_proj.zeros']

    nbit=4                        # quantization bit-width
    OC, IC = weight.shape         # output channels x input channels
    numel_per_int32 = 32//nbit    # 8 four-bit codes per int32 word

    # Tile geometry: 16 output channels x 256 input channels per tile.
    stride_oc = 16
    stride_ic = 128 * 8 // nbit   # = 256

    weight = weight.contiguous()
    # scales/zeros are stored transposed in the checkpoint; bring them to
    # (OC, n_groups) layout so their row axis matches `weight`'s.
    scales = scales.t().contiguous()
    zeros = zeros.t().contiguous()

    # NOTE(review): repeating each scale/zero 4x along the group axis turns a
    # (presumed) checkpoint group size of 128 into an effective group size of
    # 32 — TODO confirm the checkpoint was quantized with group size 128.
    group_size = 32
    scales = scales.repeat_interleave(4, dim=1)
    zeros = zeros.repeat_interleave(4, dim=1)

    # Non-overlapping tiling via unfold (stride == window size):
    #   tiled_weight:             (OC/16, IC/256, 16, 256)
    #   tiled_scales/tiled_zeros: (OC/16, IC/256, 16, 256/32 = 8)
    # i.e. one scale/zero per 32-column group inside each tile.
    tiled_weight = weight.unfold(0, stride_oc, stride_oc).unfold(1, stride_ic, stride_ic)
    tiled_scales = scales.unfold(0, stride_oc, stride_oc).unfold(1, stride_ic//group_size, stride_ic//group_size)
    tiled_zeros = zeros.unfold(0, stride_oc, stride_oc).unfold(1, stride_ic//group_size, stride_ic//group_size)

    assert tiled_weight.shape[:2] == tiled_scales.shape[:2], "pls debug"
    assert tiled_weight.shape[:2] == tiled_zeros.shape[:2], "pls debug"

    # Per-tile outputs: compressed integer codes, nonzero bitmap, nonzero count.
    tiled_qweight = torch.zeros_like(tiled_weight)
    tiled_bitmap = torch.zeros_like(tiled_weight).to(torch.bool)
    tiled_nnz = torch.zeros(tiled_weight.shape[:2]).to(torch.int16)

    # Uncompressed re-quantized codes, kept only to validate the round trip.
    non_zero_removed_tiled_qweight = torch.zeros_like(tiled_weight)
    for tile_r in range(0, tiled_weight.shape[0]):
        for tile_c in range(0, tiled_weight.shape[1]):

            sparsity, nnz = calc_sparsity(tiled_weight[tile_r, tile_c])
            print(f"tile [{tile_r:4},{tile_c:4}], sparsity: {sparsity*100:4.1f}%, nnz: {nnz:5}")

            # Record which positions in this tile hold nonzero weights.
            nonzero_bool = (tiled_weight[tile_r, tile_c] != 0)
            assert nonzero_bool.sum() == nnz, "pls debug"
            tiled_bitmap[tile_r, tile_c] = nonzero_bool
            tiled_nnz[tile_r, tile_c] = nnz

            r = tile_r
            c = tile_c

            w = tiled_weight[r, c]
            qw = torch.zeros_like(tiled_weight[r, c])
            s = tiled_scales[r, c]
            z = tiled_zeros[r, c]

            # Re-quantize group by group: invert w = (q - z) * s, i.e.
            # q = w/s + z. A zero weight therefore maps back to code z.
            for col in range(tiled_scales.shape[-1]):
                sidx = col*group_size
                eidx = (col+1)*group_size

                qw[:, sidx:eidx] = ( w[:, sidx:eidx] + (s[:,col]*z[:,col]).unsqueeze(-1) ) / s[:,col].unsqueeze(-1)

            non_zero_removed_tiled_qweight[r, c]=qw

            # Compress: pack the codes at nonzero-weight positions to the
            # front of the flattened tile; pad the tail with the constant 8.
            assert len(qw[nonzero_bool]) == nnz, "pls debug"
            compress_qw = (torch.ones_like(qw)*8).reshape(-1)
            compress_qw[:nnz] = qw[nonzero_bool]
            assert (compress_qw != 8).sum() == nnz, "pls debug"
            compress_qw = compress_qw.reshape(qw.shape)

            tiled_qweight[r, c] = compress_qw

    # Cast everything to the dtypes used in the binary dump below.
    tiled_qweight = tiled_qweight.to(torch.int32).contiguous()
    tiled_zeros = tiled_zeros.to(torch.int32).contiguous()
    tiled_scales = tiled_scales.to(torch.float16).contiguous()
    tiled_bitmap = tiled_bitmap.to(torch.int32).contiguous()
    tiled_nnz = tiled_nnz.to(torch.int16).contiguous()

    linear_nnz = tiled_nnz
    linear_scales = tiled_scales.reshape(-1)

    # Pack 8 four-bit codes into each int32; the first code of every group
    # of 8 lands in the most-significant nibble (MSB-first).
    linear_qweight = tiled_qweight.reshape(-1).reshape(-1, 8).cpu().numpy()
    linear_qweight_pack = np.zeros((linear_qweight.shape[0], 1), dtype=np.int32)
    for i in range(0, numel_per_int32):
        linear_qweight_pack[:, 0] |= linear_qweight[:, i] << (numel_per_int32 - 1 - i)*nbit
    linear_qweight_pack = linear_qweight_pack.reshape(-1)

    # Same MSB-first nibble packing for the zero-points.
    linear_zeros = tiled_zeros.reshape(-1).reshape(-1, 8).cpu().numpy()
    linear_zeros_pack = np.zeros((linear_zeros.shape[0], 1), dtype=np.int32)
    for i in range(0, numel_per_int32):
        linear_zeros_pack[:, 0] |= linear_zeros[:, i] << (numel_per_int32 - 1 - i)*nbit
    linear_zeros_pack = linear_zeros_pack.reshape(-1)

    # Pack the nonzero bitmap 32 bits per int32, MSB-first.
    linear_bitmap = tiled_bitmap.reshape(-1).reshape(-1, 32).cpu().numpy()
    linear_bitmap_pack = np.zeros((linear_bitmap.shape[0], 1), dtype=np.int32)
    for i in range(0, 32):
        linear_bitmap_pack[:, 0] |= linear_bitmap[:, i] << (32 - 1 - i)
    linear_bitmap_pack = linear_bitmap_pack.reshape(-1)

    # Dump all five arrays as raw binaries (numpy native byte order).
    os.makedirs("sparse_w4", exist_ok=True)
    linear_qweight_pack.tofile('sparse_w4/linear_compressed_qweight_int32.bin')
    linear_zeros_pack.tofile('sparse_w4/linear_zeros_int32.bin')
    linear_scales.cpu().contiguous().numpy().tofile('sparse_w4/linear_scales_float16.bin')
    linear_bitmap_pack.tofile('sparse_w4/linear_bitmap_int32.bin')
    linear_nnz.cpu().contiguous().numpy().tofile('sparse_w4/linear_nnz_int16.bin')

    print("joto")

    # ---- Round-trip verification ------------------------------------------
    # The hardcoded (896, 16) tile grid implies OC = 896*16 = 14336 and
    # IC = 16*256 = 4096 (tile sizes fixed at 16 x 256 above).
    loaded_linear_nnz = np.fromfile("sparse_w4/linear_nnz_int16.bin", dtype=np.int16)
    loaded_tiled_nnz = loaded_linear_nnz.reshape(896,16)

    assert torch.all(torch.from_numpy(loaded_tiled_nnz) == tiled_nnz), "pls debug"

    loaded_linear_scales = np.fromfile("sparse_w4/linear_scales_float16.bin", dtype=np.float16)
    loaded_tiled_scales = loaded_linear_scales.reshape(896, 16, 16, 8)

    # NOTE(review): the .to("cuda") implies the checkpoint tensors live on the
    # GPU — both comparison sides must share a device. TODO confirm.
    assert torch.all(torch.from_numpy(loaded_tiled_scales).to("cuda") == tiled_scales), "pls debug"

    # Unpack the bitmap (exact inverse of the MSB-first packing above).
    loaded_linear_bitmap_pack = np.fromfile('sparse_w4/linear_bitmap_int32.bin', dtype=np.int32)
    loaded_linear_bitmap_pack = np.expand_dims(loaded_linear_bitmap_pack, axis=-1)
    loaded_linear_bitmap = np.zeros((loaded_linear_bitmap_pack.shape[0], 32), dtype=np.int32)
    for i in range(0, 32):
        loaded_linear_bitmap[:, i] = ( loaded_linear_bitmap_pack[:, 0] >> (32 - 1 - i) ) & 0x1
    loaded_tiled_bitmap = loaded_linear_bitmap.reshape(-1).reshape(896, 16, 16, 256)

    assert torch.all(torch.from_numpy(loaded_tiled_bitmap).to("cuda") == tiled_bitmap), "pls debug"

    # Unpack the compressed 4-bit codes (mask 0xF isolates one nibble).
    loaded_linear_qweight_pack = np.fromfile('sparse_w4/linear_compressed_qweight_int32.bin', dtype=np.int32)
    loaded_linear_qweight_pack = np.expand_dims(loaded_linear_qweight_pack, axis=-1)
    loaded_linear_qweight = np.zeros((loaded_linear_qweight_pack.shape[0], numel_per_int32), dtype=np.int32)
    for i in range(0, numel_per_int32):
        loaded_linear_qweight[:, i] = ( loaded_linear_qweight_pack[:, 0] >> (numel_per_int32 - 1 - i)*nbit ) & 0xF
    loaded_tiled_qweight = loaded_linear_qweight.reshape(-1).reshape(896, 16, 16, 256)

    assert torch.all(torch.from_numpy(loaded_tiled_qweight).to("cuda") == tiled_qweight), "pls debug"

    # Unpack the zero-points the same way.
    loaded_linear_zeros_pack = np.fromfile('sparse_w4/linear_zeros_int32.bin', dtype=np.int32)
    loaded_linear_zeros_pack = np.expand_dims(loaded_linear_zeros_pack, axis=-1)
    loaded_linear_zeros = np.zeros((loaded_linear_zeros_pack.shape[0], numel_per_int32), dtype=np.int32)
    for i in range(0, numel_per_int32):
        loaded_linear_zeros[:, i] = ( loaded_linear_zeros_pack[:, 0] >> (numel_per_int32 - 1 - i)*nbit ) & 0xF
    loaded_tiled_zeros = loaded_linear_zeros.reshape(-1).reshape(896, 16, 16, 8)

    assert torch.all(torch.from_numpy(loaded_tiled_zeros).to("cuda") == tiled_zeros), "pls debug"

    # Scatter the compressed codes back to their original positions using the
    # bitmap; positions that held zero weights keep the filler code 8.
    zero_recovered_tiles = np.ones_like(loaded_tiled_qweight)*8
    for r in range(0, loaded_tiled_qweight.shape[0]):
        for c in range(0, loaded_tiled_qweight.shape[1]):
            zero_removed_padded_tile = loaded_tiled_qweight[r, c]
            nnz=loaded_tiled_nnz[r, c]
            tile_values = zero_removed_padded_tile.reshape(-1)[0:nnz]
            nnz_indices = np.nonzero(loaded_tiled_bitmap[r, c])
            zero_recovered_tiles[r, c][nnz_indices] = tile_values

    # NOTE(review): zero weights re-quantize to code z (see q = w/s + z above)
    # but recover here as filler code 8, so this equality only holds if the
    # zero-point is 8 wherever the weight is zero — TODO confirm symmetric
    # 4-bit quantization (z == 8) for this checkpoint.
    assert torch.all(non_zero_removed_tiled_qweight.to(torch.int32) == torch.from_numpy(zero_recovered_tiles).to("cuda")), "pls debug"

    # Dequantize the recovered codes group by group: w = (q - z) * s.
    dequantized_tiles = np.zeros_like(zero_recovered_tiles, dtype=np.float16)

    zero_recovered_tiles = zero_recovered_tiles.astype(np.float16)
    loaded_tiled_zeros = loaded_tiled_zeros.astype(np.float16)
    loaded_tiled_scales = loaded_tiled_scales.astype(np.float16)
    for i in range(0, zero_recovered_tiles.shape[-1], group_size):
        gid = i//group_size
        dequantized_tiles[:, :, :, i:i+group_size] = \
            ( zero_recovered_tiles[:, :, :, i:i+group_size] - \
            np.expand_dims(loaded_tiled_zeros[:, :, :, gid], axis=-1) ) * \
            np.expand_dims(loaded_tiled_scales[:, :, :, gid], axis=-1)

    print("joto")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|