# File size: 2,645 Bytes
# 35e23cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import torch
import torch.nn as nn 
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import transformers
import torch
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import warnings
# Suppress every Python warning process-wide for the rest of the run
# (presumably to keep the benchmark output readable; note this also hides
# deprecation warnings).
warnings.filterwarnings("ignore")

# Device that run_experiment moves each evaluation window to.
device = "cuda"

# GPT-2 BPE tokenizer. The models evaluated later are all loaded from the
# "gpt2" checkpoint; "openai-community/gpt2" is the Hub id that alias refers to.
tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

from datasets import load_dataset

# WikiText-2 (raw) validation split. All rows are joined with blank lines into
# one document and tokenized once up front; run_experiment slides over it.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
_joined_text = "\n\n".join(test["text"])
encodings = tokenizer(_joined_text, return_tensors="pt")
del _joined_text
import time
import gc
def run_experiment(model, stride=512, max_length=None):
    """Report memory use, sliding-window perplexity and runtime for ``model``.

    The module-level ``encodings`` (the tokenized evaluation text) is scored in
    windows of ``max_length`` tokens whose start advances by ``stride`` tokens.
    Within each window only the tokens not already scored by an earlier window
    are used as targets; the overlapping prefix serves purely as context.
    Perplexity is ``exp(total NLL / number of scored tokens)``, i.e. a
    per-token average rather than an average over windows.

    Prints, in order: the model's memory footprint (bytes / 1e6), the CUDA
    memory allocated on device 0 after the first forward pass (bytes / 1e6),
    the perplexity, and the elapsed wall-clock time in seconds.

    Args:
        model: causal LM called as ``model(input_ids, labels=...)`` that
            returns an object with ``.loss`` and provides
            ``get_memory_footprint()`` (Hugging Face ``*ForCausalLM`` models do).
        stride: number of tokens each window advances by (default 512, the
            value that used to be hard-coded).
        max_length: window length in tokens; defaults to
            ``model.config.n_positions``.

    Raises:
        ValueError: if ``stride`` is not in ``1..max_length``, or if the
            evaluation text is too short to score any token.
    """
    print(f'Memory usage of model alone = {model.get_memory_footprint()/10**6}')
    if max_length is None:
        max_length = model.config.n_positions
    if not 0 < stride <= max_length:
        # A stride longer than the window would silently skip tokens.
        raise ValueError(f"stride must be in 1..{max_length}, got {stride}")
    seq_len = encodings.input_ids.size(1)

    # Token-weighted running totals. A plain mean over windows would
    # over-weight the final window, which usually scores fewer than `stride`
    # tokens.
    nll_sum = 0.0
    n_scored = 0
    start_time = time.perf_counter()
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # -100 is the ignore index: mask the context already scored earlier.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

        if begin_loc == 0:
            print(f'Memory usage at forward pass = {torch.cuda.memory_allocated(0)/10**6}')

        # Hugging Face causal LMs shift labels left by one before computing the
        # loss, so the window's first label is never predicted. Of the
        # remaining window_len - 1 labels, the last min(trg_len, window_len - 1)
        # are unmasked. Computed arithmetically to avoid a GPU sync per window.
        window_len = end_loc - begin_loc
        n_window = min(trg_len, window_len - 1) * input_ids.size(0)
        if n_window > 0:
            # outputs.loss is the mean NLL over the scored labels; undo the mean.
            nll_sum = nll_sum + outputs.loss.float() * n_window
            n_scored += n_window

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if n_scored == 0:
        raise ValueError("evaluation text is too short to score any token")
    ppl = torch.exp(nll_sum / n_scored)
    print(f'Perplexity = {ppl.item()}')
    print(f'Time taken: {time.perf_counter() - start_time}')


from transformers import BitsAndBytesConfig

# Each experiment: (label printed before its run, quantization config, path
# the quantized model is saved to). Executed in this order.
# The first config sets no bnb_4bit_quant_type, so BitsAndBytesConfig's
# default 4-bit type applies (fp4 in current transformers — confirm for
# your installed version).
experiments = [
    ('4 bit model', BitsAndBytesConfig(load_in_4bit=True), 'bnb-4.pth'),
    ('8 bit model', BitsAndBytesConfig(load_in_8bit=True), 'bnb-8.pth'),
    (
        '4 bit nf4 model',
        BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
        'bnb-nf4.pth',
    ),
]

model = None
for label, bnb_config, save_path in experiments:
    if model is not None:
        # Release the previous model before loading the next one, including
        # anything kept alive only by reference cycles. Otherwise its weights
        # can still be resident when run_experiment reads
        # torch.cuda.memory_allocated(), inflating the later measurements.
        del model
        gc.collect()
        torch.cuda.empty_cache()
    model = AutoModelForCausalLM.from_pretrained("gpt2", quantization_config=bnb_config)
    print(label)
    run_experiment(model)
    # NOTE(review): torch.save pickles the whole module object. Loading it
    # needs the same transformers/bitsandbytes versions, and
    # torch.load(..., weights_only=False) on torch >= 2.6.
    # model.save_pretrained(...) is the supported way to persist a quantized
    # model; left unchanged so existing consumers of these .pth files still work.
    torch.save(model, save_path)
    print()