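"""Restore a GPTQ-quantized SQFT Llama-3 checkpoint to dense fp16 linears.

Every auto_gptq QuantLinear is fake-dequantized back into a plain nn.Linear
(SparseCompressLinear); greedy generation is compared before and after the
replacement, and the first decoder layer's MLP state_dict is saved to disk.
"""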
import transformers
import torch
import torch.nn as nn
import numpy as np
from transformers import LlamaForCausalLM, AutoModelForCausalLM, AutoTokenizer
from fake_dequantize import fake_dequantize
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear

DEBUG = False
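
# Dense stand-in for auto_gptq's QuantLinear: a plain nn.Linear whose forward
# optionally logs a message when `verbose` is enabled (for debugging).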
class SparseCompressLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, verbose=DEBUG):
        super(SparseCompressLinear, self).__init__(in_features, out_features, bias)
        self.verbose = verbose  # for debug

    def forward(self, input):
        if self.verbose is True:
            print("SparseCompressLinear Forward!")
        return super(SparseCompressLinear, self).forward(input)

    def __repr__(self):
        # Custom print out
        return f"SparseCompressLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None})"
def make_linear_from_QuantLinear(QuantLinearObj):
    device = QuantLinearObj.scales.device
    qweight = QuantLinearObj.qweight
    scales = QuantLinearObj.scales
    qzeros = QuantLinearObj.qzeros

    with torch.no_grad():
        W, scales, zeros = fake_dequantize(qweight, scales, qzeros)

    IC, OC = W.shape
    linear = SparseCompressLinear(in_features=IC, out_features=OC, bias=(QuantLinearObj.bias is not None))

    assert linear.weight.shape == W.t().shape, "Logical Error"
    linear.weight.data = W.t().contiguous()

    if QuantLinearObj.bias is not None:
        linear.bias.data = QuantLinearObj.bias

    linear.register_buffer("scales", scales)
    linear.register_buffer("zeros", zeros)
    return linear.to(device)
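
# Recursively walk the module tree and swap every QuantLinear for its
# fake-dequantized dense equivalent, cast to fp16.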
def replace_QuantLinear_with_SparseCompressLinear(model):
    for name, module in model.named_children():
        if isinstance(module, QuantLinear):
            if DEBUG is True:
                print(f"Restoring {name}")
            restored_linear = make_linear_from_QuantLinear(module)
            restored_linear = restored_linear.to(torch.float16)  # TODO: Hardcoding
            setattr(model, name, restored_linear)
        else:
            # Recursively apply to child modules
            replace_QuantLinear_with_SparseCompressLinear(module)
    return model
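
# Compare greedy generation from the GPTQ model against the fake-dequantized
# model, then dump the first decoder layer's MLP state_dict for inspection.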
if __name__ == "__main__":
    # model_id = "/data4/vchua/hf-model/Meta-Llama-3-8B-Instruct"
    # model_id = "/data4/vchua/hf-model/Meta-Llama-3-70B"
    model_id = "/home/vchua/sqft-qa-sparsepeft-llama-3-8b-50-gptq-gsm8k"

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    prompt = "Alan Turing theorized that computers would one day become"
    input_ids = tokenizer([prompt]).input_ids
    input_ids = torch.as_tensor(input_ids)

    # -----------------------------------------
    output_ids = model.generate(
        input_ids.cuda(), do_sample=False, top_p=None, num_beams=1, max_new_tokens=256
    )
    output_sqft = tokenizer.batch_decode(output_ids.cpu())
    print(f"\n++ Baseline sqft output:\n\n{output_sqft[0]}\n\n")

    # -----------------------------------------
    replace_QuantLinear_with_SparseCompressLinear(model)

    output_ids = model.generate(
        input_ids.cuda(), do_sample=False, top_p=None, num_beams=1, max_new_tokens=256
    )
    output_fake_dequantize = tokenizer.batch_decode(output_ids.cpu())
    print(f"\n++ fake dequantize sqft output:\n\n{output_fake_dequantize[0]}\n\n")

    tx1mlp = model.model.layers[0].mlp
    torch.save(tx1mlp.state_dict(), "./sqft_llama3_8B_gptq_tx1_mlp.pth")

    # -----------------------------------------
    print()
    # torch.save(tx1mlp.state_dict(), "./sqft_llama3_8B_gptq_tx1_mlp.pth")