import gradio as gr
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


# Core quantization routine (simulated quantization): weights are rounded to n-bit integers
# and immediately dequantized, so the tensor keeps float storage but carries the quantization error.
def pseudo_quantize_tensor(w, n_bit=4, q_group_size=-1):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)

    assert w.dim() == 2

    # Calculate the maximum (\alpha) and minimum values (\beta) in the tensor.
    max_val = w.amax(dim=1, keepdim=True)
    assert max_val.dim() == 2 and max_val.size(0) == w.size(0) and max_val.size(1) == 1
    min_val = w.amin(dim=1, keepdim=True)
    assert min_val.dim() == 2 and min_val.size(0) == w.size(0) and min_val.size(1) == 1

    # Calculate the scale factor and zero point.  (Formula 1 & 2)
    max_int = 2 ** n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    assert scales.shape == max_val.shape
    zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
    assert zeros.shape == min_val.shape

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    # Quantize W: Map values in the range [\beta, \alpha] to lie within [0, 2^b - 1] (Formula 3)
    w = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    assert w.dim() == 2 and w.size(0) == scales.size(0) and (q_group_size <= 0 or w.size(1) == q_group_size)

    # Dequantize W (pseudo quantization, the inverse transformation of Formula 3)
    w = (w - zeros) * scales
    assert w.dim() == 2 and w.size(0) == scales.size(0) and (q_group_size <= 0 or w.size(1) == q_group_size)

    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)
    return w
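
# Quick sanity check (illustrative helper; the name and defaults are arbitrary): quantize
# a random 2-D tensor and measure the mean round-trip error to get a feel for how much
# information n-bit, group-wise pseudo-quantization discards.
def quantization_error_demo(rows=256, cols=512, n_bit=4, q_group_size=128):
    w = torch.randn(rows, cols)
    w_q = pseudo_quantize_tensor(w, n_bit=n_bit, q_group_size=q_group_size)
    return (w - w_q).abs().mean().item()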

@torch.no_grad()
def pseudo_quantize_model_weight(
    model, w_bit, q_group_size,
):
    for n, m in model.named_modules():
        if isinstance(m, nn.Linear):
            m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, q_group_size=q_group_size)


# Load the tokenizer and model
model_path = "facebook/opt-1.3b"
offload_folder = "offload"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", offload_folder=offload_folder)

# Load a second copy and pseudo-quantize its linear-layer weights to 3 bits (group size 128).
model_q = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", offload_folder=offload_folder)
pseudo_quantize_model_weight(model_q, w_bit=3, q_group_size=128)
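
# Rough footprint estimate (illustrative; pseudo-quantization keeps float storage, so this
# only shows what a real 3-bit packing of the nn.Linear weights would save).
linear_params = sum(m.weight.numel() for m in model_q.modules() if isinstance(m, nn.Linear))
linear_bytes = sum(m.weight.numel() * m.weight.element_size()
                   for m in model_q.modules() if isinstance(m, nn.Linear))
print(f"Linear weights: {linear_params / 1e6:.1f}M parameters, "
      f"~{linear_bytes / 2**20:.0f} MiB as stored vs ~{linear_params * 3 / 8 / 2**20:.0f} MiB if packed at 3 bits")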
# Define functions for model inference.


# Build a text-generation pipeline around the already-loaded full-precision model
# (reusing it avoids loading a third copy of the weights).
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

def generate_text_pip(prompt):
    generated_text = generator(prompt, max_length=1000, num_return_sequences=1)[0]['generated_text']
    return generated_text

print(generate_text_pip("I went to Boston and"))  # quick smoke test of the pipeline

def generate_text(prompt):
    # Greedy decoding with the full-precision model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=100)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def generate_text_from_quantized(prompt):
    # Greedy decoding with the pseudo-quantized (3-bit, group size 128) model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model_q.device)
    output_ids = model_q.generate(**inputs, max_new_tokens=100)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
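
# Convenience helper (illustrative; the dict keys are just labels): run one prompt through
# both models so the full-precision and 3-bit outputs can be compared side by side
# outside the Gradio UI.
def compare_outputs(prompt):
    return {
        "full-precision": generate_text(prompt),
        "quantized (3-bit, g128)": generate_text_from_quantized(prompt),
    }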

# Create a Gradio interface for each model.
iface = gr.Interface(fn=generate_text, inputs="text", outputs="text", live=True)
iface_2 = gr.Interface(fn=generate_text_from_quantized, inputs="text", outputs="text", live=True)

app = gr.TabbedInterface([iface, iface_2], ["Normal", "Quantized"])

# Launch the Gradio app
app.launch(server_name="0.0.0.0", share=True)