import gc
import time
from threading import Thread

import gradio as gr

import torch
from transformers import AutoTokenizer, TextIteratorStreamer, set_seed
from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
from configuration_tricksy import TricksyConfig

def generate_text(prompt, max_new_tokens, top_k, top_p, use_tricksy):
    set_seed(42)
    model_name = 'facebook/opt-6.7b'
    disk_weights = OPTDiskWeights(model_name)
    tricksy_model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)), disk_weights)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    inputs = tokenizer(prompt[:500], return_tensors='pt').input_ids.to('cuda')

    generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p)
    thread = Thread(target=tricksy_model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ''
    for new_text in streamer:
        generated_text += new_text
        yield generated_text, ''
    
    stats_text = f'Decoding tok/s: {1 / (sum(tricksy_model.tricksy_context.forward_times[1:]) / (len(tricksy_model.tricksy_context.forward_times) - 1)):.2f}'
    stats_text += f'  \nCurrent GPU mem usage: {torch.cuda.memory_allocated("cuda") / 1024 ** 3:.2f} GB'
    stats_text += f'  \nMax GPU mem usage: {torch.cuda.max_memory_allocated("cuda") / 1024 ** 3:.2f} GB'

    disk_weights = None
    tricksy_model.clear()
    tricksy_model = None
    time.sleep(.2)
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    yield generated_text, stats_text

css = """
h1 {
    text-align: center;
    display:block;
}
"""

with gr.Blocks(css=css) as iface:
    gr.Markdown('# Tricksy-OPT 6.7b')
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(label="Prompt", value='Making pesto from scratch can be done with these ingredients in 4 simple steps:\nStep 1')
            with gr.Accordion("Additional inputs"):
                use_tricksy = gr.Checkbox(value=True, label="Use Tricksy", info="If true, only send the sparse MLP weight diff to the GPU. If false, send all the weights to the GPU.")
                max_new_tokens = gr.Slider(minimum=1, maximum=100, value=100, label="Max new tokens")
                top_k = gr.Slider(minimum=1, maximum=500, value=50, label="Top-k sampling")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p (nucleus sampling)")
                use_tricksy.change(
                    lambda x: 100 if x else 10,
                    inputs=[use_tricksy],
                    outputs=[max_new_tokens]
                )
                max_new_tokens.change(
                    lambda x, y: 10 if x > 10 and not y else x,
                    inputs=[max_new_tokens, use_tricksy],
                    outputs=[max_new_tokens]
                )
        with gr.Column():
            out = gr.Textbox(label="Generated Text")
            stats = gr.Textbox(label="Statistics")
    btn = gr.Button("Generate")
    btn.click(
        generate_text,
        inputs=[
            prompt,
            max_new_tokens,
            top_k,
            top_p,
            use_tricksy,
        ],
        outputs=[out, stats]
    )
    with gr.Accordion("Description", open=False):
        gr.Markdown('''
            MLP layers of large language models are naturally sparse--e.g. > 99% of layer 3's and > 90% of layer 20's neurons in OPT-1.3b have no effect (due to ReLU) for most inputs. Adjacent tokens also share a significant number of active neurons--e.g. for layers 1-7 of OPT-1.3b, > 90% of neurons active for token k are also active for token k + 1 (and 60-65% for layers 20-23).

            We exploit this natural sparsity to minimize CPU-GPU data transfer.
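
            As a rough, hypothetical sketch (not part of this demo's code), per-layer sparsity can be estimated by hooking each decoder layer's `fc1` output in a stock Hugging Face OPT model and counting non-positive pre-activations:

            ```python
            import torch
            from transformers import AutoTokenizer, OPTForCausalLM

            model_name = 'facebook/opt-1.3b'
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = OPTForCausalLM.from_pretrained(model_name).eval()

            sparsity = {}

            def make_hook(layer_idx):
                def hook(module, inputs, output):
                    # OPT applies ReLU after fc1, so non-positive fc1 outputs mean inactive neurons
                    sparsity[layer_idx] = (output <= 0).float().mean().item()
                return hook

            for i, layer in enumerate(model.model.decoder.layers):
                layer.fc1.register_forward_hook(make_hook(i))

            with torch.no_grad():
                model(**tokenizer('Making pesto from scratch', return_tensors='pt'))

            for i, frac in sparsity.items():
                print(f'layer {i}: {frac:.1%} of MLP neurons inactive')
            ```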

            ### At initialization, we:
            1. Store a subset of each MLP layer (e.g. 30%) and full attention layers on the GPU (similar to [LLM in a flash](https://arxiv.org/abs/2312.11514))
            2. Store full MLP layers in CPU RAM
            3. Store a cache of which neuron indices we currently have on the GPU
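
            A minimal sketch of that initialization for one decoder layer (hypothetical names and sizes, not the actual `TricksyOPTForCausalLM` buffers):

            ```python
            import torch

            hidden_size, ffn_dim = 2048, 8192   # OPT-1.3b sizes, for illustration
            gpu_fraction = 0.3                  # e.g. keep 30% of MLP neurons on the GPU
            n_gpu_neurons = int(ffn_dim * gpu_fraction)

            # 2. Full MLP weights stay in pinned CPU RAM (random stand-ins here)
            fc1_cpu = torch.randn(ffn_dim, hidden_size).pin_memory()
            fc2_cpu = torch.randn(hidden_size, ffn_dim).pin_memory()

            # 3. Cache of which neuron indices currently live on the GPU
            gpu_neuron_indices = torch.arange(n_gpu_neurons)

            # 1. Copy that subset (rows of fc1, columns of fc2) to the GPU;
            #    attention weights (not shown) are copied in full
            fc1_gpu = fc1_cpu[gpu_neuron_indices].cuda()
            fc2_gpu = fc2_cpu[:, gpu_neuron_indices].cuda()
            ```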

            ### Before each decoder layer's forward pass, we:
            1. Predict active MLP neurons based on the attention layer input (following [Deja Vu](https://proceedings.mlr.press/v202/liu23am/liu23am.pdf))
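
            A sketch of a Deja Vu-style predictor (hypothetical, untrained, illustrative sizes): a small low-rank MLP scores every FFN neuron from the hidden state entering the layer, and the top-scoring neurons are treated as predicted-active:

            ```python
            import torch
            import torch.nn as nn

            hidden_size, ffn_dim, rank = 2048, 8192, 256

            predictor = nn.Sequential(
                nn.Linear(hidden_size, rank),
                nn.ReLU(),
                nn.Linear(rank, ffn_dim),
            )

            def predict_active_neurons(attn_input, k):
                scores = predictor(attn_input)           # one score per FFN neuron
                return torch.topk(scores, k).indices     # indices of predicted-active neurons

            x = torch.randn(hidden_size)                 # hidden state entering the decoder layer
            predicted = predict_active_neurons(x, k=int(0.3 * ffn_dim))
            ```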

            ### During each decoder layer's attention computation, we, asynchronously on the CPU:
            1. Compute the difference between the set of predicted active neuron indices and the set of neuron indices we currently have on the GPU
            2. Index those neurons from CPU RAM
            3. Copy them to the GPU
            4. Update the layer's neuron indices cache
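
            A sketch of that diff-and-copy step (hypothetical: the real code overlaps the CPU-side indexing with the attention computation, e.g. on a background thread; this only shows a non-blocking GPU copy on a side stream, reusing the names from the initialization sketch above):

            ```python
            import torch

            copy_stream = torch.cuda.Stream()

            def queue_neuron_diff(predicted, gpu_neuron_indices, fc1_cpu, fc2_cpu):
                # 1. Predicted-active neurons we don't already have on the GPU
                have = set(gpu_neuron_indices.tolist())
                diff = torch.tensor([i for i in predicted.tolist() if i not in have], dtype=torch.long)
                if diff.numel() == 0:
                    return diff, None, None

                # 2. Index them out of CPU RAM (advanced indexing returns an unpinned copy,
                #    so re-pin before the transfer; see Potential Improvements below)
                fc1_diff = fc1_cpu[diff].pin_memory()
                fc2_diff = fc2_cpu[:, diff].contiguous().pin_memory()

                # 3. Copy to the GPU without blocking the main stream
                with torch.cuda.stream(copy_stream):
                    fc1_diff = fc1_diff.to('cuda', non_blocking=True)
                    fc2_diff = fc2_diff.to('cuda', non_blocking=True)

                # 4. The caller updates the layer's neuron index cache using `diff`
                return diff, fc1_diff, fc2_diff
            ```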

            ### And finally, during each decoder layer's MLP computation, we:
            1. Concatenate the newly received neuron diff with our existing neurons
            2. Compute the MLP (**Note**: As long as fully-connected layer 1 and fully-connected layer 2 share the same neuron ordering, the full two-layer computation is invariant with respect to neuron ordering.)
            3. Overwrite a subset of our neuron buffer with the diff (FIFO order)
            4. Delete the diff
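
            A sketch of these MLP steps (hypothetical, biases omitted; a real FIFO would rotate a write pointer instead of always overwriting the front of the buffer):

            ```python
            import torch

            def sparse_mlp(x, fc1_gpu, fc2_gpu, fc1_diff, fc2_diff, gpu_neuron_indices, diff):
                # 1. Concatenate the newly received neuron diff with the resident neurons
                fc1 = torch.cat([fc1_gpu, fc1_diff], dim=0)
                fc2 = torch.cat([fc2_gpu, fc2_diff], dim=1)

                # 2. Standard two-layer MLP over the (arbitrarily ordered) active neurons
                out = torch.relu(x @ fc1.T) @ fc2.T

                # 3. Overwrite a slice of the resident buffers with the diff and keep the
                #    index cache in sync
                n = fc1_diff.shape[0]
                fc1_gpu[:n] = fc1_diff
                fc2_gpu[:, :n] = fc2_diff
                gpu_neuron_indices[:n] = diff

                # 4. The diff tensors are freed once they go out of scope
                return out
            ```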

            ## Limitations
            1. This is approximate inference. The active neuron predictors do not have perfect recall, leading to slight accuracy degradation. See the [Deja Vu paper](https://proceedings.mlr.press/v202/liu23am/liu23am.pdf) for an in-depth evaluation.

            ## Potential Improvements
            1. Evaluations--to push the sparsity levels, we need evaluations to measure accuracy degradation.
            2. Indexing the non-contiguous neuron diff from CPU RAM comes nowhere near saturating CPU-RAM memory bandwidth. We may be able to improve this with a custom C++ indexer.
            3. Early layers are extremely sparse while later layers are less sparse--perhaps we can allocate smaller GPU neuron buffers to early layers to free up space for larger buffers for later layers.
            4. Applying an advanced index to a pinned tensor in PyTorch returns an unpinned copy of the indexed data, which must be re-copied to pinned memory before it can be sent to the GPU. If we could override this default behavior to allow direct CPU-GPU copying from a specified advanced index without intermediate copies, we should get a nice speedup.
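
            A small demonstration of the pinned-memory behavior described in item 4 (requires a CUDA build of PyTorch):

            ```python
            import torch

            weights = torch.randn(8192, 2048).pin_memory()
            idx = torch.randint(0, 8192, (512,))

            diff = weights[idx]          # advanced indexing allocates a new, unpinned tensor
            print(weights.is_pinned())   # True
            print(diff.is_pinned())      # False -> an extra copy back to pinned memory is needed
                                         # before a fast, asynchronous transfer to the GPU
            ```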
        ''')

iface.queue(max_size=5).launch(show_api=False)