"""Gradio demo that serves text generation from a pseudo-quantized OPT model.

The module also carries the simulated (pseudo) weight-quantization helpers
used to produce such a checkpoint, plus a helper that estimates a model's
storage size in bits.
"""

import gc
from functools import partial

import gradio as gr
import torch
import tqdm
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


def get_model_size(model: nn.Module, data_width=16, group_size=-1):
    """Return the estimated storage size of ``model`` in bits.

    Args:
        model: Module whose parameters are counted.
        data_width: Number of bits used to store one parameter element.
        group_size: Quantization group size. When it is not ``-1``, an extra
            ``(16 + 4) / group_size`` bits per element is added to account
            for per-group metadata (presumably a 16-bit scale and a 4-bit
            zero point per group -- confirm against the storage format).

    Returns:
        ``number_of_parameter_elements * bits_per_element``. Divide by
        ``Byte`` / ``KiB`` / ``MiB`` / ``GiB`` to convert units.
    """
    if group_size != -1:
        data_width += (16 + 4) / group_size

    num_elements = sum(param.numel() for param in model.parameters())
    return num_elements * data_width


# Unit sizes expressed in bits, matching the return value of get_model_size.
Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB


# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=4, q_group_size=-1):
    """Quantize ``w`` to ``n_bit`` unsigned integers and dequantize it again.

    The result has the same shape and dtype family as the input but only
    takes values that are representable by asymmetric min/max quantization,
    so it simulates the quantization error without changing the storage.

    Args:
        w: Weight tensor. When ``q_group_size > 0`` its last dimension must
            be divisible by ``q_group_size``; otherwise it must be 2-D and
            each row is quantized as a single group.
        n_bit: Number of bits of the simulated integer representation.
        q_group_size: Number of consecutive elements that share one scale
            and zero point. A non-positive value means one group per row.

    Returns:
        The dequantized tensor, reshaped to the original shape of ``w``.

    Raises:
        AssertionError: If the shape requirements above are not met or if
            NaNs are encountered.
    """
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0, (
            "last dimension must be divisible by q_group_size"
        )
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2, "expected a 2-D tensor (one quantization group per row)"

    # With q_group_size <= 0 every row is one group, so the group width is
    # the row length rather than the (negative) q_group_size sentinel.
    num_groups, group_width = w.shape

    # Calculate the maximum (\alpha) and minimum values (\beta) in each group.
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    assert max_val.shape == (num_groups, 1)
    assert min_val.shape == (num_groups, 1)

    # Calculate the scale factor and zero point. (Formula 1 & 2)
    max_int = 2 ** n_bit - 1
    # The clamp keeps the scale strictly positive for constant groups.
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
    assert scales.shape == max_val.shape
    assert zeros.shape == min_val.shape

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    # Quantize W: Map values in the range [\beta, \alpha] to lie within
    # [0, 2^b - 1] (Formula 3)
    w = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    assert w.shape == (num_groups, group_width)

    # Dequantize W (pseudo quantization, the inverse transformation of
    # Formula 3)
    w = (w - zeros) * scales
    assert w.shape == (num_groups, group_width)
    assert torch.isnan(w).sum() == 0

    return w.reshape(org_w_shape)


@torch.no_grad()
def pseudo_quantize_model_weight(
    model, w_bit, q_group_size,
):
    """Pseudo-quantize the weight of every ``nn.Linear`` in ``model`` in place.

    Args:
        model: Module whose linear layers are rewritten.
        w_bit: Bit width passed to ``pseudo_quantize_tensor`` as ``n_bit``.
        q_group_size: Group size passed to ``pseudo_quantize_tensor``.
    """
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.weight.data = pseudo_quantize_tensor(
                m.weight.data, n_bit=w_bit, q_group_size=q_group_size
            )


# Load the tokenizer and model
model_path = "facebook/opt-125m"
model_q_path = "facebook/opt-125m_3bit"
offload_folder = "offload"

model = AutoModelForCausalLM.from_pretrained(
    model_q_path, device_map="auto", offload_folder=offload_folder
)
generator = pipeline('text-generation', model=model_q_path)
# generator_q = pipeline('text-generation', model="facebook/opt-125m-awq")


def generate_text_pip(prompt):
    """Return the pipeline's single generated continuation for ``prompt``."""
    outputs = generator(prompt, max_length=50, num_return_sequences=1)
    return outputs[0]['generated_text']


# Disabled: second pipeline for a separately quantized checkpoint.
# def generate_text_pip_q(prompt):
#     generated_text = generator_q(prompt, max_length=50, num_return_sequences=1)[0]['generated_text']
#     return generated_text

print(generator("I went to boston and"))
# print("quantized model", generator_q("I went to boston and"))

# Disabled: single-forward-pass decoding that bypasses the pipeline.
# def generate_text(prompt):
#     inputs = tokenizer(prompt, return_tensors="pt")
#     output = model(**inputs)
#     logits = output.logits
#     predicted_ids = logits.argmax(-1)
#     # generated_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
#     generated_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
#     return generated_text
#
# def generate_text_from_quantized(prompt):
#     inputs = tokenizer(prompt, return_tensors="pt")
#     output = model_q(**inputs)
#     logits = output.logits
#     predicted_ids = logits.argmax(-1)
#     # generated_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
#     generated_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
#     return generated_text

# Create a Gradio interface
# NOTE(review): data_width=32 with group_size=128 does not obviously match a
# checkpoint named "3bit" -- confirm the intended bit width.
model_size = get_model_size(model, data_width=32, group_size=128)
# NOTE(review): the text below names OPT-1.3b and a fixed 4.21 GB, while the
# checkpoint loaded above is an opt-125m variant -- confirm the displayed values.
description = (
    "Model Name : OPT-1.3b\n"
    "Original Model Size : 4.21 GB\n"
    f"Quantized Model Size : {model_size/MiB:.2f} MiB"
)

iface = gr.Interface(
    fn=generate_text_pip, inputs="text", outputs="text", description=description
)
# iface_2 = gr.Interface(fn=generate_text_pip_q, inputs="text", outputs="text")
iface.launch()

# app = gr.TabbedInterface([iface, iface_2], ["Normal", "Quantized"])
# Launch the Gradio app
# app.launch()