import torch
import numpy as np

blob = torch.load("./opt-125m-gptq4.pth")


def verify_unpack_logic(prepack, pack, nbit=4):
    """Unpack the GPTQ int4 tensors in `pack` and check that (1) the unpacked
    integers match pack['intweight'] and (2) the simulated dequantization
    matches the fp16 reference weight prepack['w']."""
    numel_per_int32 = 32 // nbit
    qweight = pack['qweight'].numpy()   # (IC//numel_per_int32, OC), int32
    scales = pack['scales'].numpy()     # (ngroup, OC)
    qzeros = pack['qzeros'].numpy()     # (ngroup, OC//numel_per_int32)
    IC = qweight.shape[0] * numel_per_int32
    OC = qweight.shape[1]
    group_size = IC // scales.shape[0]

    # unpack qweight: numel_per_int32 consecutive input-channel rows share one
    # int32 row, with row k stored in bits [k*nbit, (k+1)*nbit)
    qweight_unpack = np.zeros((IC, OC), dtype=np.float32)
    for row in range(0, qweight.shape[0]):
        for k in range(0, numel_per_int32):
            qweight_unpack[row * numel_per_int32 + k, :] = ((qweight[row] >> k * nbit) & 0xF).astype(np.float32)  # read as int32 and cast to float32

    intweight_match = torch.allclose(
        torch.from_numpy(qweight_unpack).to(torch.int32),
        torch.from_numpy(pack['intweight'].astype(np.int32))
    )
    assert intweight_match, "intweight and qweight_unpack do not match! pls debug"

    scales_float = scales.astype(np.float32)

    # unpack qzeros: within each int32 word, output column offset i comes from
    # the (numel_per_int32 - 1 - i)-th nibble, i.e. the reverse nibble order
    # TODO: verify with asym zero points; sym zero points are all identical
    qzeros_unpack = np.zeros(list(scales.shape), dtype=np.float32)
    for i in range(0, numel_per_int32):
        shift_multiplier = numel_per_int32 - 1 - i
        shift_by = shift_multiplier * nbit
        qzeros_unpack[:, i::numel_per_int32] = ((qzeros >> shift_by) & 0xF).astype(np.float32)  # read as int32 and cast to float32
    qzeros_unpack += 1  # the pack step stores (zero - 1), so add 1 back

    qweight_unpack = torch.from_numpy(qweight_unpack).to('cuda').to(torch.float16)
    qzeros_unpack = torch.from_numpy(qzeros_unpack).to('cuda').to(torch.float16)
    scales_float = torch.from_numpy(scales_float).to('cuda').to(torch.float16)

    # simulate dequantization per group of group_size input channels:
    # w = (q - zero) * scale
    deqweight_unpack = torch.zeros((IC, OC), dtype=torch.float16)
    for i in range(IC):
        gid = i // group_size
        deqweight_unpack[i, :] = (qweight_unpack[i, :] - qzeros_unpack[gid, :]) * scales_float[gid, :]

    simulated_match = torch.allclose(deqweight_unpack, prepack['w'].t(), atol=0.0005)
    assert simulated_match, "prepack['w'] and deqweight_unpack do not match! pls debug"
    print(f"intweight_match: {intweight_match}, simulated_match: {simulated_match}")


for layer, lblob in blob.items():
    print(f"\n\n--> {layer}")
    prepack = lblob['prepack']
    pack = lblob['pack']
    # for k, v in prepack.items():
    #     print(f"prepack['{k:10}'] : {str(tuple(v.shape)):<20}")
    # for k, v in pack.items():
    #     print(f"pack['{k:13}'] : {str(tuple(v.shape)):<20}")
    verify_unpack_logic(prepack, pack)
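

# ---------------------------------------------------------------------------
# Optional sanity check (illustrative only, not part of the .pth verification
# above): pack a tiny synthetic int4 matrix with the inverse of the qweight
# unpack loop in verify_unpack_logic(), then confirm the round trip recovers
# the original integers. The function name, shapes, and seed below are made up
# for this sketch and do not come from the checkpoint.
def roundtrip_sanity_check(IC=16, OC=8, nbit=4, seed=0):
    numel_per_int32 = 32 // nbit
    rng = np.random.default_rng(seed)
    intweight = rng.integers(0, 2 ** nbit, size=(IC, OC)).astype(np.int64)

    # pack: numel_per_int32 consecutive rows share one word per column, row k
    # occupying bits [k*nbit, (k+1)*nbit); int64 is used here purely to avoid
    # worrying about the sign bit (the real qweight is stored as int32)
    qweight = np.zeros((IC // numel_per_int32, OC), dtype=np.int64)
    for row in range(qweight.shape[0]):
        for k in range(numel_per_int32):
            qweight[row] |= intweight[row * numel_per_int32 + k] << (k * nbit)

    # unpack with the same shift-and-mask logic as verify_unpack_logic()
    unpacked = np.zeros((IC, OC), dtype=np.int64)
    for row in range(qweight.shape[0]):
        for k in range(numel_per_int32):
            unpacked[row * numel_per_int32 + k, :] = (qweight[row] >> (k * nbit)) & 0xF

    assert np.array_equal(unpacked, intweight), "round trip failed"
    print("roundtrip_sanity_check passed")

# roundtrip_sanity_check()  # uncomment to run the synthetic round-trip check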