import argparse
import math
import datetime
import time
import os
import gc
from tqdm import tqdm
import copy

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

import torch
import torch.multiprocessing as mp
from torch import nn, optim
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from lib import codebook, utils
from lib.algo import quip, preprocess, outlier_channel_split as ocs

import glog

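# Command-line options for the quantization run: base model, calibration/devset sizes,
# codebook choice, incoherence processing mode, outlier channel splitting, etc.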
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--num_cpu_threads', default=8, type=int)
parser.add_argument('--batch_size', default=8, type=int)
parser.add_argument('--devset_size', default=64, type=int)
parser.add_argument('--ctx_size', default=2048, type=int)
parser.add_argument('--save_path', type=str)
parser.add_argument('--hessian_path', type=str)
parser.add_argument('--base_model', default='meta-llama/Llama-2-70b-hf', type=str)
parser.add_argument('--sigma_reg', default=1e-2, type=float)
parser.add_argument('--sigma_reg2', default=1e-2, type=float)
parser.add_argument('--incoh_mode', default='had', type=str, choices=['had', 'kron'])
parser.add_argument('--lora_rank', default=0, type=int, help='if <=0 then turned off')
parser.add_argument('--scale_override', default=-1, type=float)
parser.add_argument('--codebook', default='D4', type=str)
parser.add_argument('--quip_tune_iters', default=10, type=int)
parser.add_argument('--remove_mean', action='store_true')
parser.add_argument('--outlier_channel_split', action='store_true')
parser.add_argument('--ocs_down_size', default=2**15, type=int)
parser.add_argument('--use_fp64', action='store_true')
parser.add_argument('--full_svd', action='store_true')
parser.add_argument('--no_use_buffered', action='store_true')
parser.add_argument('--q_buffer_size', default=2, type=int)
parser.add_argument('--rescale_WH', action='store_true')
parser.add_argument('--sample_proc', default=1, type=int)

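# Quantize q_proj, k_proj, and v_proj jointly: the three matrices share the same input
# activations, so they are RMS-normalized, stacked, and quantized against one Hessian.
# Results are cached at {save_path}/{idx}_qkv.pt and reloaded on subsequent runs.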
def quantize_kqv(layer, idx, cb, args, device='cpu', check_only=False):
    dtype_ = torch.float64 if args.use_fp64 else torch.float32
    hatw_path = f'{args.save_path}/{idx}_qkv.pt'

    W_q = layer.self_attn.q_proj.weight
    W_k = layer.self_attn.k_proj.weight
    W_v = layer.self_attn.v_proj.weight
    W_q_scale = W_q.to(dtype_).square().mean().sqrt().to(dtype_)
    W_k_scale = W_k.to(dtype_).square().mean().sqrt().to(dtype_)
    W_v_scale = W_v.to(dtype_).square().mean().sqrt().to(dtype_)

    if os.path.exists(hatw_path):
        if check_only:
            return
        hatW = utils.load_quip(hatw_path, cb, args, device)
        glog.info(f'loaded saved hatW from {hatw_path}')
    else:
        H_data = torch.load(f'{args.hessian_path}/{idx}_qkv.pt', map_location=torch.device('cpu'))
        H = utils.flat_to_sym(H_data['flatH'], H_data['n'])
        mu = H_data['mu']
        n = H_data['n']
        W_qkv = torch.vstack((W_q.to(dtype_) / W_q_scale, W_k.to(dtype_) / W_k_scale,
                              W_v.to(dtype_) / W_v_scale)).to(dtype_)
        H, mu = preprocess.basic_preprocess(H, mu, n, args)
        hatW, attr = quip.quantize(H, W_qkv, args.lora_rank, cb, args, device)
        attr.update({
            'W_q_scale': W_q_scale.cpu(),
            'W_k_scale': W_k_scale.cpu(),
            'W_v_scale': W_v_scale.cpu(),
        })
        torch.save(attr, hatw_path)
        utils.show_metrics(hatW, W_qkv, H.to(dtype_), f'layer {idx} qkv')
        utils.clean()

    W_q_next = (hatW[0:W_q.shape[0], :] * W_q_scale).half()
    W_k_next = (hatW[W_q.shape[0]:W_q.shape[0] + W_k.shape[0], :] * W_k_scale).half()
    W_v_next = (hatW[W_q.shape[0] + W_k.shape[0]:W_q.shape[0] + W_k.shape[0] + W_v.shape[0], :] *
                W_v_scale).half()

    if args.remove_mean:
        layer.self_attn.q_proj.bias = nn.Parameter(
            (W_q.to(dtype_) @ mu - W_q_next.to(dtype_) @ mu).half())
        layer.self_attn.k_proj.bias = nn.Parameter(
            (W_k.to(dtype_) @ mu - W_k_next.to(dtype_) @ mu).half())
        layer.self_attn.v_proj.bias = nn.Parameter(
            (W_v.to(dtype_) @ mu - W_v_next.to(dtype_) @ mu).half())

    W_q.copy_(W_q_next)
    W_k.copy_(W_k_next)
    W_v.copy_(W_v_next)

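# Quantize the attention output projection (o_proj) against its own Hessian.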
def quantize_o(layer, idx, cb, args, device='cpu', check_only=False):
    dtype_ = torch.float64 if args.use_fp64 else torch.float32
    hatw_path = f'{args.save_path}/{idx}_o.pt'

    W_o = layer.self_attn.o_proj.weight
    W_o_scale = W_o.to(dtype_).square().mean().sqrt().to(dtype_)

    if os.path.exists(hatw_path):
        if check_only:
            return
        hatW = utils.load_quip(hatw_path, cb, args, device)
        glog.info(f'loaded saved hatW from {hatw_path}')
    else:
        H_data = torch.load(f'{args.hessian_path}/{idx}_o.pt', map_location=torch.device('cpu'))
        H = utils.flat_to_sym(H_data['flatH'], H_data['n'])
        mu = H_data['mu']
        n = H_data['n']
        W_orig = W_o.to(dtype_) / W_o_scale
        H, mu = preprocess.basic_preprocess(H, mu, n, args)
        hatW, attr = quip.quantize(H, W_orig, args.lora_rank, cb, args, device)
        attr.update({'W_o_scale': W_o_scale})
        torch.save(attr, hatw_path)
        utils.show_metrics(hatW, W_orig, H.to(dtype_), f'layer {idx} o')
        utils.clean()

    W_o_next = (hatW * W_o_scale).half()

    if args.remove_mean:
        layer.self_attn.o_proj.bias = nn.Parameter(
            (W_o.to(dtype_) @ mu - W_o_next.to(dtype_) @ mu).half())

    W_o.copy_(W_o_next)

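# Quantize up_proj and gate_proj jointly; like q/k/v they share input activations and
# therefore a single Hessian.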
def quantize_up(layer, idx, cb, args, device='cpu', check_only=False):
    dtype_ = torch.float64 if args.use_fp64 else torch.float32
    hatw_path = f'{args.save_path}/{idx}_up.pt'

    W_up = layer.mlp.up_proj.weight
    W_gate = layer.mlp.gate_proj.weight
    W_up_scale = W_up.to(dtype_).square().mean().sqrt().to(dtype_)
    W_gate_scale = W_gate.to(dtype_).square().mean().sqrt().to(dtype_)

    if os.path.exists(hatw_path):
        if check_only:
            return
        glog.info(f'loading saved hatW from {hatw_path}')
        hatW = utils.load_quip(hatw_path, cb, args, device)
    else:
        H_data = torch.load(f'{args.hessian_path}/{idx}_up.pt', map_location=torch.device('cpu'))
        H = utils.flat_to_sym(H_data['flatH'], H_data['n'])
        mu = H_data['mu']
        n = H_data['n']
        W_upgate = torch.vstack(
            (W_up.to(dtype_) / W_up_scale, W_gate.to(dtype_) / W_gate_scale)).to(dtype_)
        H, mu = preprocess.basic_preprocess(H, mu, n, args)

        hatW, attr = quip.quantize(H, W_upgate, args.lora_rank, cb, args, device)
        attr.update({
            'W_up_scale': W_up_scale,
            'W_gate_scale': W_gate_scale,
        })
        torch.save(attr, hatw_path)
        utils.show_metrics(hatW, W_upgate, H.to(dtype_), f'layer {idx} up')
        utils.clean()

    W_up_next = (hatW[0:W_up.shape[0], :] * W_up_scale).half()
    W_gate_next = (hatW[W_up.shape[0]:W_up.shape[0] + W_gate.shape[0], :] * W_gate_scale).half()

    if args.remove_mean:
        layer.mlp.up_proj.bias = nn.Parameter(
            (W_up.to(dtype_) @ mu - W_up_next.to(dtype_) @ mu).half())
        layer.mlp.gate_proj.bias = nn.Parameter(
            (W_gate.to(dtype_) @ mu - W_gate_next.to(dtype_) @ mu).half())

    W_up.copy_(W_up_next)
    W_gate.copy_(W_gate_next)

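# Quantize down_proj. With --outlier_channel_split, outlier input channels are split
# (duplicated) to the target ocs_down_size before quantization and fused back afterwards.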
def quantize_down(layer, idx, cb, args, device='cpu', check_only=False):
    dtype_ = torch.float64 if args.use_fp64 else torch.float32
    hatw_path = f'{args.save_path}/{idx}_down.pt'

    W_down = layer.mlp.down_proj.weight
    W_down_scale = W_down.to(dtype_).square().mean().sqrt().to(dtype_)

    if os.path.exists(hatw_path):
        if check_only:
            return
        glog.info(f'loading saved hatW from {hatw_path}')
        hatW = utils.load_quip(hatw_path, cb, args, device)
        if args.outlier_channel_split:
            extra_inds = torch.load(hatw_path)['ocs_extra_inds']
    else:
        H_data = torch.load(f'{args.hessian_path}/{idx}_down.pt', map_location=torch.device('cpu'))
        H = utils.flat_to_sym(H_data['flatH'], H_data['n'])
        mu = H_data['mu']
        n = H_data['n']
        if args.outlier_channel_split:
            glog.info(f'outlier channel splitting to {args.ocs_down_size}')
            W_down, H, mu, extra_inds, dupe_inds = ocs.outlier_channel_split(
                W_down, H, mu, args.ocs_down_size)
            n = args.ocs_down_size
            utils.clean()
        W_orig = W_down.to(dtype_) / W_down_scale
        H, mu = preprocess.basic_preprocess(H, mu, n, args)
        hatW, attr = quip.quantize(H, W_orig, args.lora_rank, cb, args, device)
        attr.update({'W_down_scale': W_down_scale})
        if args.outlier_channel_split:
            attr['ocs_extra_inds'] = extra_inds
            attr['ocs_dupe_inds'] = dupe_inds
        torch.save(attr, hatw_path)
        utils.show_metrics(hatW, W_orig, H.to(dtype_), f'layer {idx} down')
        utils.clean()

    W_down_next = (hatW * W_down_scale).half()

    if args.remove_mean:
        layer.mlp.down_proj.bias = nn.Parameter(
            (W_down.to(dtype_) @ mu - W_down_next.to(dtype_) @ mu).half())

    if args.outlier_channel_split:
        W_down_next = ocs.fuse_W(W_down_next, extra_inds)

    layer.mlp.down_proj.weight.copy_(W_down_next)

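# Quantize every linear projection in one transformer block. With return_layer=False the
# results are only written to disk (used by the worker processes); with return_layer=True
# the updated layer is returned for the activation/perplexity pass in main().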
def quantize_layer(layer, idx, cb, args, device='cpu', return_layer=False):
    torch.manual_seed(idx)
    torch.set_grad_enabled(False)

    utils.clean()
    quantize_kqv(layer, idx, cb, args, device, check_only=not return_layer)
    utils.clean()
    quantize_o(layer, idx, cb, args, device, check_only=not return_layer)
    utils.clean()
    quantize_up(layer, idx, cb, args, device, check_only=not return_layer)
    utils.clean()
    quantize_down(layer, idx, cb, args, device, check_only=not return_layer)
    utils.clean()

    glog.info(f'finished layer {idx}')
    if return_layer:
        return layer

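# Worker loop for multi-GPU runs: pull (layer, idx) pairs from the queue until a None
# sentinel arrives, quantizing each layer on this worker's device.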
def quantize_layer_queue(in_q, cb, args, device):
    while True:
        next_item = in_q.get()
        if next_item is None:
            return
        quantize_layer(*next_item, cb, args, device, False)

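# End-to-end driver: load the model and calibration data, quantize all decoder layers
# (in parallel across GPUs when available), then compare per-layer activation error and
# devset perplexity between the original and quantized models.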
def main(args):
    dtype_ = torch.float64 if args.use_fp64 else torch.float32

    cb = codebook.get_codebook(args.codebook)

    model = AutoModelForCausalLM.from_pretrained(args.base_model,
                                                 torch_dtype='auto',
                                                 low_cpu_mem_usage=True)

    all_config = {'quant_args': args, 'model_config': model.config}
    all_config['model_config'].update({
        'quip_params': {
            'outlier_channel_split': args.outlier_channel_split,
            'lora_rank': args.lora_rank,
            'rescale_WH': args.rescale_WH,
            'codebook': args.codebook,
            'codebook_version': cb.version,
            'codesz': cb.codesz,
            'idx_dtype': str(cb.idx_dtype),
            'fused': True,
            'packsz': cb.packsz,
        }
    })
    if args.outlier_channel_split:
        all_config['model_config'].quip_params['ocs_down_size'] = args.ocs_down_size
    torch.save(all_config, os.path.join(args.save_path, 'config.pt'))

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.pad_token = tokenizer.eos_token
    glog.info('loaded model')

    dataset = load_dataset('togethercomputer/RedPajama-Data-1T-Sample', split='train')
    devset = utils.sample_devset(dataset, tokenizer, args.devset_size, args.ctx_size,
                                 args.sample_proc)
    glog.info('loaded dataset and devset')

    nproc = torch.cuda.device_count()

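    # With more than one GPU, quantize layers in parallel: one worker process per device,
    # each consuming deep copies of the decoder layers from a bounded queue.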
    if nproc > 1:
        layer_q = mp.Queue(maxsize=args.q_buffer_size)

        quantize_procs = []
        for i in range(nproc):
            p = mp.Process(target=quantize_layer_queue, args=(layer_q, cb, args, i))
            p.start()
            quantize_procs.append(p)
        for _ in range(len(model.model.layers)):
            layer_q.put((copy.deepcopy(model.model.layers[_]), _))
        for p in quantize_procs:
            layer_q.put(None)
        for p in quantize_procs:
            p.join()

        glog.info('done quantizing')

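    # Single-device pass: embed the calibration set once and build the shared position ids
    # and causal attention mask used to propagate activations layer by layer.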
    device = 0

    orig_emb = model.model.embed_tokens(devset)
    quant_emb = orig_emb.clone()
    position_ids = torch.arange(args.ctx_size, dtype=torch.int32)[None, :].to(device) + \
        torch.zeros(args.batch_size, args.ctx_size, dtype=torch.int32).to(device)
    if hasattr(model.config, 'sliding_window'):
        attention_mask = model.model._prepare_decoder_attention_mask(
            torch.ones(args.batch_size, args.ctx_size, dtype=torch.bool),
            (args.batch_size, args.ctx_size),
            quant_emb[0:args.batch_size],
            0,
            sliding_window=model.config.sliding_window).to(device)
    else:
        attention_mask = model.model._prepare_decoder_attention_mask(
            torch.ones(args.batch_size, args.ctx_size, dtype=torch.bool),
            (args.batch_size, args.ctx_size), quant_emb[0:args.batch_size], 0).to(device)

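    # For each decoder layer: run the original activations through the unquantized layer,
    # quantize the layer (reusing on-disk results from the workers if present), run the
    # quantized activation stream through it, and log the relative activation error.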
    for i in range(len(model.model.layers)):
        model.model.layers[i] = model.model.layers[i].to(device)

        for j in range(args.devset_size // args.batch_size):
            orig_emb[args.batch_size * j:args.batch_size * (j + 1)] = \
                model.model.layers[i](
                    orig_emb[args.batch_size * j:args.batch_size * (j + 1)].to(device),
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    output_attentions=False)[0].cpu()

        model.model.layers[i] = model.model.layers[i].cpu()

        model.model.layers[i] = quantize_layer(model.model.layers[i],
                                               i,
                                               cb,
                                               args,
                                               device=device,
                                               return_layer=True).to(device)

        for j in range(args.devset_size // args.batch_size):
            quant_emb[args.batch_size * j:args.batch_size * (j + 1)] = \
                model.model.layers[i](
                    quant_emb[args.batch_size * j:args.batch_size * (j + 1)].to(device),
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    output_attentions=False)[0].cpu()

        model.model.layers[i] = model.model.layers[i].cpu()
        model.model.layers[i] = None

        act_error = (quant_emb.to(dtype_) - orig_emb.to(dtype_)).square().sum() / \
            (orig_emb.to(dtype_) - orig_emb.to(dtype_).mean((0, 1))).square().sum()

        glog.info(f'layer {i} activation error {act_error}')

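    # Evaluate both activation streams through the final norm and LM head to compare
    # devset perplexity before and after quantization.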
    glog.info('calculating perplexity on devset')

    lm_head = model.lm_head.to(dtype_)
    lm_head.to(device)

    norm = model.model.norm.to(dtype_)
    norm.to(device)

    acc = 0.0
    for i in tqdm(range(args.devset_size // args.batch_size), desc='original model perplexity'):
        shift_logits = lm_head(
            norm(orig_emb[args.batch_size * i:args.batch_size * (i + 1)].to(device).to(
                dtype_)))[..., :-1, :].contiguous().view(-1, model.config.vocab_size)
        shift_labels = devset[args.batch_size * i:args.batch_size * (i + 1),
                              1:].contiguous().view(-1).to(device)
        loss_fct = nn.CrossEntropyLoss().to(device)
        acc += loss_fct(shift_logits, shift_labels)
    # average the summed batch losses over the number of batches before exponentiating
    perplexity = (acc / (args.devset_size // args.batch_size)).exp()
    glog.info(f'original model perplexity: {perplexity}')

    acc = 0.0
    for i in tqdm(range(args.devset_size // args.batch_size), desc='quantized model perplexity'):
        shift_logits = lm_head(
            norm(quant_emb[args.batch_size * i:args.batch_size * (i + 1)].to(device).to(
                dtype_)))[..., :-1, :].contiguous().view(-1, model.config.vocab_size)
        shift_labels = devset[args.batch_size * i:args.batch_size * (i + 1),
                              1:].contiguous().view(-1).to(device)
        loss_fct = nn.CrossEntropyLoss().to(device)
        acc += loss_fct(shift_logits, shift_labels)
    perplexity = (acc / (args.devset_size // args.batch_size)).exp()
    glog.info(f'quantized model perplexity: {perplexity}')

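# Example invocation (script filename and directory paths are placeholders):
#   python quantize_llama.py --base_model meta-llama/Llama-2-70b-hf \
#       --hessian_path <hessian_dir> --save_path <quantized_ckpt_dir> --codebook D4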
if __name__ == '__main__':
    torch.set_grad_enabled(False)
    mp.set_start_method('spawn')
    args = parser.parse_args()
    torch.set_num_threads(args.num_cpu_threads)
    torch.manual_seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)
    main(args)