import gradio as gr
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from functools import partial
import gc
def get_model_size(model: nn.Module, data_width=16, group_size=-1):
    """Estimate the storage size of ``model``'s parameters, in bits.

    Every parameter element is charged ``data_width`` bits. When
    ``group_size > 0`` (grouped quantization, the same convention
    ``pseudo_quantize_tensor`` uses for ``q_group_size``), each element is
    additionally charged ``(16 + 4) / group_size`` bits: a per-group overhead
    (presumably a 16-bit scale plus a 4-bit zero point) amortized over the
    group. ``group_size <= 0`` means no grouping and no overhead.

    Args:
        model: Module whose ``parameters()`` are counted.
        data_width: Bits per parameter element.
        group_size: Quantization group size; ``<= 0`` disables grouping.

    Returns:
        Total size in bits (a float when grouping overhead is added).
    """
    bits_per_element = data_width
    # Only positive group sizes are real groups; 0 used to divide by zero and
    # other negatives used to subtract bits.
    if group_size > 0:
        bits_per_element = data_width + (16 + 4) / group_size
    num_elements = sum(param.numel() for param in model.parameters())
    return num_elements * bits_per_element
# Size units measured in bits (one byte = 8 bits), so that a bit count divided
# by one of these constants gives the size in that unit.
Byte = 8
KiB = Byte << 10
MiB = KiB << 10
GiB = MiB << 10
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=4, q_group_size=-1):
    """Simulate asymmetric n-bit integer quantization of ``w`` (quantize, then dequantize).

    Args:
        w: Tensor to quantize. Without grouping it is viewed as rows along its
            last dimension; with grouping, the last dimension is split into
            consecutive groups of ``q_group_size`` elements.
        n_bit: Bit width; quantized values lie in ``[0, 2**n_bit - 1]``.
        q_group_size: Group size when ``> 0``, where each group gets its own
            scale and zero point. When ``<= 0``, each row along the last
            dimension gets one scale and zero point.

    Returns:
        Tensor with the same shape as ``w`` holding the dequantized values.

    Raises:
        ValueError: if ``q_group_size > 0`` does not divide ``w.shape[-1]``.
    """
    org_w_shape = w.shape
    if q_group_size > 0:
        if org_w_shape[-1] % q_group_size != 0:
            raise ValueError(
                f"last dimension {org_w_shape[-1]} is not divisible by "
                f"q_group_size={q_group_size}"
            )
        w = w.reshape(-1, q_group_size)
    else:
        # One scale/zero point per row; this is a no-op reshape for 2-D input.
        w = w.reshape(-1, org_w_shape[-1])
    assert w.dim() == 2
    num_groups, group_width = w.shape

    # Calculate the maximum (\alpha) and minimum values (\beta) of each row/group.
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    assert max_val.shape == min_val.shape == (num_groups, 1)

    # Calculate the scale factor and zero point. (Formula 1 & 2)
    max_int = 2 ** n_bit - 1
    # The clamp keeps the scale non-zero for constant rows/groups.
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    # NOTE(review): the zero point is clamped to [0, max_int]. When a group's
    # [min, max] range does not contain 0, values toward one end of the range
    # may be clipped by the integer clamp below.
    zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
    assert not torch.isnan(scales).any()
    assert not torch.isnan(w).any()

    # Quantize W: map values in [\beta, \alpha] onto [0, 2^b - 1] (Formula 3).
    w_int = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    # Dequantize W (inverse of Formula 3); returning float values is what makes
    # this "pseudo" quantization.
    w = (w_int - zeros) * scales
    assert w.shape == (num_groups, group_width)
    assert not torch.isnan(w).any()

    return w.reshape(org_w_shape)
@torch.no_grad()
def pseudo_quantize_model_weight(model, w_bit, q_group_size):
    """Pseudo-quantize, in place, the weight of every ``nn.Linear`` in ``model``.

    Each Linear layer's ``weight.data`` is replaced by the output of
    ``pseudo_quantize_tensor`` with ``n_bit=w_bit`` and the given
    ``q_group_size``. Biases and non-Linear modules are left untouched. The
    ``torch.no_grad()`` decorator keeps autograd from recording the update.
    """
    linear_layers = (mod for mod in model.modules() if isinstance(mod, nn.Linear))
    for layer in linear_layers:
        layer.weight.data = pseudo_quantize_tensor(
            layer.weight.data, n_bit=w_bit, q_group_size=q_group_size
        )
# Load the tokenizer and model
model_path = "facebook/opt-125m"  # not referenced in the visible code; kept for reference
model_q_path = "facebook/opt-125m_3bit"
offload_folder = "offload"
model = AutoModelForCausalLM.from_pretrained(model_q_path, device_map="auto", offload_folder=offload_folder)
tokenizer = AutoTokenizer.from_pretrained(model_q_path)
# Reuse the model loaded above (and the path variable) instead of letting
# pipeline() load the same checkpoint a second time from a hard-coded string.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
#generator_q = pipeline('text-generation', model="facebook/opt-125m-awq")
def generate_text_pip(prompt, max_new_tokens=50):
    """Generate text for ``prompt`` with the module-level ``generator`` pipeline.

    Args:
        prompt: Input text.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The ``generated_text`` field of the single returned sequence.
    """
    # max_new_tokens bounds only the generated part. max_length would also count
    # the prompt's tokens, so long prompts would get little or no continuation.
    outputs = generator(prompt, max_new_tokens=max_new_tokens, num_return_sequences=1)
    return outputs[0]["generated_text"]
# NOTE(review): disabled code. This triple-quoted string is an unused expression,
# so the function inside is never defined. It also depends on `generator_q`, whose
# pipeline definition is commented out.
'''
def generate_text_pip_q(prompt):
generated_text = generator_q(prompt, max_length=50, num_return_sequences=1)[0]['generated_text']
return generated_text
'''
# Import-time smoke test: runs one generation through `generator` and prints the result.
print(generator("I went to boston and"))
#print("quantized model",generator_q("I went to boston and"))
# NOTE(review): disabled code kept in a triple-quoted string, so neither function
# below is defined at runtime. If re-enabled, it relies on module-level
# `tokenizer` and `model_q`, and `model_q` is not defined in the visible file.
# Both functions also make a single forward pass and decode the argmax over the
# last dimension of the logits (one predicted token per input position). That
# is not an autoregressive continuation, which the text-generation pipeline
# produces.
'''
def generate_text(prompt):
inputs = tokenizer(prompt, return_tensors="pt")
output = model(**inputs)
logits = output.logits
predicted_ids = logits.argmax(-1)
#generated_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
generated_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
return generated_text
def generate_text_from_quantized(prompt):
inputs = tokenizer(prompt, return_tensors="pt")
output = model_q(**inputs)
logits = output.logits
predicted_ids = logits.argmax(-1)
#generated_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
generated_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
return generated_text
'''
# Create a Gradio interface
# Size of the parameters as actually loaded, in bits: element_size() is bytes per
# element for each tensor's real dtype, and Byte converts bytes to bits (the
# unit get_model_size() returns).
original_size_bits = sum(p.numel() * p.element_size() for p in model.parameters()) * Byte
# Estimated size if every parameter were stored at QUANT_BITS bits, with
# per-group (128 elements) overhead. A width of 32 would exceed a float32
# model's size and so cannot describe a quantized model.
QUANT_BITS = 3  # NOTE(review): bit-width inferred from the checkpoint name "opt-125m_3bit"; confirm
model_size = get_model_size(model, data_width=QUANT_BITS, group_size=128)
# Gradio renders `description` as Markdown, where a single newline does not
# start a new line; blank lines keep each entry on its own line.
description = "\n\n".join(
    [
        f"Model Name : {model_q_path}",
        f"Original Model Size : {original_size_bits / MiB:.2f} MiB",
        f"Quantized Model Size : {model_size / MiB:.2f} MiB",
    ]
)
iface = gr.Interface(fn=generate_text_pip, inputs="text", outputs="text", description=description)
#iface_2 = gr.Interface(fn=generate_text_pip_q, inputs="text", outputs="text")
iface.launch()
#app = gr.TabbedInterface([iface, iface_2],["Normal", "Quantized"])
# Launch the Gradio app
#app.launch()