austinsilveria committed on
Commit
75fa479
1 Parent(s): b2a1f5e

Add tricksy

Files changed (4)
  1. app.py +47 -2
  2. configuration_tricksy.py +18 -0
  3. modeling_tricksy.py +618 -0
  4. util.py +83 -0
app.py CHANGED
@@ -1,4 +1,49 @@
+ from threading import Thread
+
  import streamlit as st
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import torch
+ from transformers import AutoTokenizer, TextIteratorStreamer, set_seed
+ from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
+ from configuration_tricksy import TricksyConfig
+
+ def generate():
+     set_seed(42)
+
+     # 13.4 GB (16 bit)
+     model_name = 'facebook/opt-6.7b'
+     disk_weights = OPTDiskWeights(model_name)
+     tricksy_model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)), disk_weights)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+     inputs = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
+
+     print()
+     generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p)
+     thread = Thread(target=tricksy_model.generate, kwargs=generation_kwargs)
+     thread.start()
+     generated_text = ''
+     with st.chat_message("user"):
+         t = st.empty()
+         for new_text in streamer:
+             generated_text += new_text.replace('\n', ' \n')
+             t.write(generated_text)
+
+     stats_text = f'Decoding tok/s: {1 / (sum(tricksy_model.tricksy_context.forward_times[1:]) / (len(tricksy_model.tricksy_context.forward_times) - 1))}'
+     stats_text += f' \nCurrent GPU mem usage: {torch.cuda.memory_allocated("cuda") / 1024 ** 3} GB'
+     stats_text += f' \nMax GPU mem usage: {torch.cuda.max_memory_allocated("cuda") / 1024 ** 3} GB'
+     st.write(stats_text)
+
+ prompt = st.text_area('Prompt', 'Making pesto from scratch can be done with these ingredients in 4 simple steps:\nStep 1')
+
+ col1, col2 = st.columns(2)
+ with col1:
+     submit = st.button('Submit', on_click=generate)
+ with col2:
+     use_tricksy = st.toggle('Use Tricksy', True, help='If true, only send sparse MLP weight diffs to GPU. If false, send all weights to GPU.')
+
+ with st.expander('Additional options'):
+     max_new_tokens = st.slider('Max new tokens', 1, 500, 100)
+     top_k = st.slider('Top-k sampling', 1, 500, 50)
+     top_p = st.slider('Top-p (nucleus sampling)', 0.0, 1.0, .9)
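For anyone trying the demo outside the hosted Space, an app like this is normally launched with the Streamlit CLI; a minimal sketch, assuming a CUDA GPU and that the dependencies imported above are installed:

    pip install streamlit torch transformers accelerate huggingface_hub
    streamlit run app.py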
configuration_tricksy.py ADDED
@@ -0,0 +1,18 @@
+ import dataclasses
+ import torch
+ from transformers.models.opt.configuration_opt import OPTConfig
+
+ @dataclasses.dataclass(frozen=True)
+ class TricksyConfig:
+     opt_config: OPTConfig
+
+     # Percentage of weights to keep on each device
+     # e.g. 30% of each MLP layer on GPU
+     min_mlp_sparsity_gpu: float = .3
+     # e.g. 100% of each MLP layer on CPU
+     min_mlp_sparsity_cpu: float = 1
+
+     # If true, cleans up layer's weights after computing forward pass
+     full_offload: bool = False
+
+     dtype: torch.dtype = torch.float16
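For context, this config is just a frozen wrapper around an OPTConfig; a minimal construction sketch (not part of the commit, assuming transformers is installed):

    from transformers.models.opt.configuration_opt import OPTConfig
    from configuration_tricksy import TricksyConfig

    # Keep 30% of each MLP layer resident on the GPU; set full_offload=True to emulate
    # the baseline that streams full weights instead of sparse diffs.
    config = TricksyConfig(OPTConfig.from_pretrained('facebook/opt-6.7b'), min_mlp_sparsity_gpu=0.3, full_offload=False)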
modeling_tricksy.py ADDED
@@ -0,0 +1,618 @@
+ from typing import Any, Optional, Callable, List, Tuple
+ import os
+ import time
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from accelerate import init_empty_weights
+ from transformers.activations import ACT2FN
+ from transformers.generation import GenerationConfig
+ from transformers.models.opt.modeling_opt import (
+     OPTAttention,
+     OPTDecoder,
+     OPTDecoderLayer,
+     OPTForCausalLM,
+     OPTModel,
+ )
+ from transformers.models.opt.configuration_opt import OPTConfig
+ from huggingface_hub import snapshot_download
+
+ from configuration_tricksy import TricksyConfig
+ from util import batch_copy, compute_index_diffs, load_mlp_sparsity_predictor, mmap_to_tensor, topk_and_threshold
+
+ TRICKSY_WEIGHTS_PATH = 'tricksy-weights/'
+
+ class SparseMLPCache:
+     def __init__(
+         self,
+         indexed_fc1_weight: Optional[torch.Tensor] = None,
+         indexed_fc1_bias: Optional[torch.Tensor] = None,
+         indexed_fc2_weight: Optional[torch.Tensor] = None,
+         gpu_cached_mlp_indices: Optional[torch.Tensor] = None,
+     ):
+         # [ffn_embed_dim * min_mlp_sparsity, hidden_size]
+         self.indexed_fc1_weight = indexed_fc1_weight
+         # [ffn_embed_dim * min_mlp_sparsity]
+         self.indexed_fc1_bias = indexed_fc1_bias
+         # [ffn_embed_dim * min_mlp_sparsity, hidden_size] (stored in transpose for efficient indexing)
+         self.indexed_fc2_weight = indexed_fc2_weight
+
+         # Indices that are already on GPU (this tensor is stored on the CPU)
+         # [ffn_embed_dim * min_mlp_sparsity]
+         self.gpu_cached_mlp_indices = gpu_cached_mlp_indices
+
+ class SparseIndices:
+     def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
+         self.mlp_indices_buffer_gpu = torch.empty(
+             (int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu),),
+             dtype=torch.int32,
+             device='cuda'
+         )
+         self.mlp_indices_buffer_cpu = torch.empty(
+             (int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu),),
+             dtype=torch.int32,
+             device='cpu',
+             pin_memory=True,
+         )
+
+         # Default stream blocks until indices are copied to CPU
+         self.index_copy_stream = torch.cuda.default_stream()
+
+     def copy_mlp_indices_to_cpu(self):
+         self.mlp_indices_buffer_cpu = batch_copy([self.mlp_indices_buffer_gpu], self.index_copy_stream, device='cpu')[0]
+
+ class OPTDiskWeights:
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         self.model_suffix = model_name.split('/')[-1]
+         self.config = OPTConfig.from_pretrained(model_name)
+
+         try:
+             print(f'downloading from austinsilveria/tricksy-{self.model_suffix}')
+             self.weight_path = snapshot_download(repo_id=f'austinsilveria/tricksy-{self.model_suffix}') + '/'
+         except:
+             print(f'failed to download from austinsilveria/tricksy-{self.model_suffix}')
+             self.weight_path = f'{TRICKSY_WEIGHTS_PATH}{self.model_suffix}/'
+
+         with init_empty_weights():
+             model = OPTModel(self.config)
+         self.state_dict = model.state_dict()
+
+         if not os.path.exists(f'{self.weight_path}decoder.embed_tokens.weight'):
+             # Download original weights and write memmap files
+             print(f'downloading and preprocessing original weights')
+             self.cache_weights()
+
+         head_dim = self.config.hidden_size // self.config.num_attention_heads
+         for i in range(self.config.num_hidden_layers):
+             layer_prefix = f'decoder.layers.{i}.'
+             self.delete_weights([
+                 f'{layer_prefix}self_attn.q_proj.weight',
+                 f'{layer_prefix}self_attn.k_proj.weight',
+                 f'{layer_prefix}self_attn.v_proj.weight',
+                 f'{layer_prefix}self_attn.out_proj.weight',
+                 f'{layer_prefix}self_attn.q_proj.bias',
+                 f'{layer_prefix}self_attn.k_proj.bias',
+                 f'{layer_prefix}self_attn.v_proj.bias'
+             ])
+             self.add_weights([
+                 (f'{layer_prefix}fc2.weight', (self.config.ffn_dim, self.config.hidden_size)),
+                 (f'{layer_prefix}self_attn.catted_head_weights', (self.config.num_attention_heads, head_dim * 4, self.config.hidden_size)),
+                 (f'{layer_prefix}self_attn.catted_head_biases', (self.config.num_attention_heads, 3, head_dim)),
+             ])
+
+         self.memmap_weights = { key: self.load_memmap_weight(key) for key in self.state_dict.keys() }
+
+     def load_memmap_weight(self, key: str):
+         return torch.from_numpy(np.memmap(f'{self.weight_path}{key}', dtype='float16', mode='r', shape=(self.state_dict[key].shape)))
+
+     def add_weights(self, weights: List[Tuple[str, torch.Size]]):
+         for key, shape in weights:
+             self.state_dict[key] = torch.empty(shape, dtype=torch.float16, device='meta')
+
+     def delete_weights(self, keys: List[str]):
+         for key in keys:
+             if key in self.state_dict:
+                 del self.state_dict[key]
+             path = f'{self.weight_path}{key}'
+             if os.path.exists(path):
+                 os.remove(path)
+
+     def cache_weights(self):
+         os.makedirs(self.weight_path, exist_ok=True)
+         weights_location = snapshot_download(repo_id=self.model_name, ignore_patterns=['flax*', 'tf*'])
+         shards = [file for file in os.listdir(weights_location) if file.startswith("pytorch_model") and file.endswith(".bin")]
+         for shard in shards:
+             print(f'caching {shard}')
+             shard_path = os.path.join(weights_location, shard)
+             shard_state_dict = torch.load(shard_path)
+             for key in shard_state_dict.keys():
+                 path = f'{self.weight_path}{key.replace("model.", "")}'
+                 memmap = np.memmap(path, dtype='float16', mode='w+', shape=(shard_state_dict[key].shape))
+                 memmap[:] = shard_state_dict[key].cpu().numpy()
+
+         # Store weights in shape for efficient indexing
+         for i in range(self.config.num_hidden_layers):
+             layer_prefix = f'decoder.layers.{i}.'
+             # FC2 in transpose
+             fc2t = torch.from_numpy(np.array(self.load_memmap_weight(f'{layer_prefix}fc2.weight')[:])).t().contiguous().clone()
+             np.memmap(f'{self.weight_path}decoder.layers.{i}.fc2.weight', dtype='float16', mode='w+', shape=fc2t.shape)[:] = fc2t.numpy()
+
+             # Attention weights by head
+             head_dim = self.config.hidden_size // self.config.num_attention_heads
+             qw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.weight')[:])
+             kw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.weight')[:])
+             vw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.weight')[:])
+             ow = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.out_proj.weight')[:])
+             pre_cat_shape = (self.config.num_attention_heads, head_dim, self.config.hidden_size)
+             # [head, head_dim * 4, hidden_size]
+             catted_head_weights = torch.cat(
+                 [qw.view(pre_cat_shape).clone(), kw.view(pre_cat_shape).clone(), vw.view(pre_cat_shape).clone(), ow.T.view(pre_cat_shape).clone(),],
+                 dim=1,
+             ).contiguous().clone()
+             np.memmap(f'{self.weight_path}{layer_prefix}self_attn.catted_head_weights', dtype='float16', mode='w+', shape=catted_head_weights.shape)[:] =\
+                 catted_head_weights.numpy()
+
+             # Attention biases by head
+             qb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.bias')[:])
+             kb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.bias')[:])
+             vb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.bias')[:])
+             pre_cat_shape = (self.config.num_attention_heads, 1, head_dim)
+             # [head, 3, head_dim]
+             catted_head_biases = torch.cat(
+                 # Don't index out bias since we need all dims after projecting back up to hidden size
+                 [qb.view(pre_cat_shape).clone(), kb.view(pre_cat_shape).clone(), vb.view(pre_cat_shape).clone()],
+                 dim=1,
+             ).contiguous().clone()
+             np.memmap(f'{self.weight_path}{layer_prefix}self_attn.catted_head_biases', dtype='float16', mode='w+', shape=catted_head_biases.shape)[:] =\
+                 catted_head_biases.numpy()
+
+             self.delete_weights([
+                 f'{layer_prefix}self_attn.q_proj.weight',
+                 f'{layer_prefix}self_attn.k_proj.weight',
+                 f'{layer_prefix}self_attn.v_proj.weight',
+                 f'{layer_prefix}self_attn.out_proj.weight',
+                 f'{layer_prefix}self_attn.q_proj.bias',
+                 f'{layer_prefix}self_attn.k_proj.bias',
+                 f'{layer_prefix}self_attn.v_proj.bias'
+             ])
+             self.add_weights([
+                 (f'{layer_prefix}self_attn.catted_head_weights', catted_head_weights.shape),
+                 (f'{layer_prefix}self_attn.catted_head_biases', catted_head_biases.shape),
+             ])
+
+ class TricksyContext:
+     def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
+         self.indices = SparseIndices(tricksy_config, opt_config)
+         self.load_weight_stream = torch.cuda.Stream()
+         self.layer = 0
+         self.is_prompt_phase = True
+         self.forward_times = []
+
+ class TricksyLayer:
+     def __call__(self, *args: Any, **kwds: Any) -> Any:
+         return self.forward(*args, **kwds)
+
+     def load_weights(self, tricksy_context: TricksyContext):
+         pass
+
+ class TricksyLayerInputs:
+     def __init__(
+         self,
+         disk_weights: OPTDiskWeights,
+         layer_key_prefix: str = None,
+         next_layer: TricksyLayer = None,
+         sparsity_predictors: List[Callable[[torch.Tensor], torch.Tensor]] = None,
+     ) -> None:
+         self.disk_weights = disk_weights
+         # self.get_weight = lambda key: self.disk_weights.load_memmap_weight(f'{layer_key_prefix}{key}')
+         self.get_weight = lambda key: self.disk_weights.memmap_weights[(f'{layer_key_prefix}{key}')]
+         self.layer_key_prefix = layer_key_prefix
+         self.next_layer = next_layer
+         self.sparsity_predictors = sparsity_predictors
+
+ class TricksyOPTLearnedPositionalEmbedding(TricksyLayer):
+     """
+     This module learns positional embeddings up to a fixed maximum size.
+     """
+
+     def __init__(self, tricksy_context):
+         # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+         # and adjust num_embeddings appropriately. Other models don't have this hack
+         self.offset = 2
+         self.tricksy_context = tricksy_context
+         self.weight = None
+
+     def __call__(self, *args: Any, **kwds: Any) -> Any:
+         return self.forward(*args, **kwds)
+
+     def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+         """`input_ids_shape` is expected to be [bsz x seqlen]."""
+         attention_mask = attention_mask.long()
+         # create positions depending on attention_mask
+         positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+         # cut positions if `past_key_values_length` is > 0
+         positions = positions[:, past_key_values_length:]
+
+         out = F.embedding(positions + self.offset, self.weight)
+         return out
+
+ class TricksyOPTAttention(OPTAttention, TricksyLayer):
+     def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext, is_decoder: bool = False, **kwargs):
+         nn.Module.__init__(self)
+         self.tricksy_config = tricksy_config
+         self.config = tricksy_config.opt_config
+
+         def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
+             """
+             If the deprecated argument `fn_arg_name` is passed, raise a deprecation
+             warning and return that value, otherwise take the equivalent config.config_arg_name
+             """
+             val = None
+             if fn_arg_name in kwargs:
+                 print(
+                     f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
+                     " Please set it in the config instead"
+                 )
+                 val = kwargs.pop(fn_arg_name)
+             else:
+                 val = getattr(config, config_arg_name)
+             return val
+
+         self.embed_dim = _handle_deprecated_argument("hidden_size", self.config, "embed_dim", kwargs)
+         self.num_heads = _handle_deprecated_argument("num_attention_heads", self.config, "num_heads", kwargs)
+         self.dropout = _handle_deprecated_argument("attention_dropout", self.config, "dropout", kwargs)
+         self.enable_bias = _handle_deprecated_argument("enable_bias", self.config, "bias", kwargs)
+
+         self.head_dim = self.embed_dim // self.num_heads
+         self.is_causal = True
+
+         if (self.head_dim * self.num_heads) != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                 f" and `num_heads`: {self.num_heads})."
+             )
+         self.scaling = self.head_dim**-0.5
+         self.is_decoder = is_decoder
+
+         # [Tricksy]
+         self.tricksy_context = tricksy_context
+         self.inputs = inputs
+         self.head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+         self.qw = self.kw = self.vw = self.ow = self.qb = self.kb = self.vb = self.out_proj_bias = self.layer_norm_weight = self.layer_norm_bias = None
+         self.q_proj = lambda x: F.linear(x, self.qw, self.qb)
+         self.k_proj = lambda x: F.linear(x, self.kw, self.kb)
+         self.v_proj = lambda x: F.linear(x, self.vw, self.vb)
+         self.out_proj = lambda x: F.linear(x, self.ow, self.out_proj_bias)
+         self.layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.layer_norm_weight, self.layer_norm_bias)
+
+     def load_weights(self, tricksy_context: TricksyContext):
+         if self.tricksy_context.is_prompt_phase:
+             # Full weights for prompt phase
+             self.catted_weights, self.catted_biases, self.out_proj_bias, self.layer_norm_weight, self.layer_norm_bias = batch_copy(
+                 [
+                     mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_weights')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_biases')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('self_attn.out_proj.bias')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.weight')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.bias')[:], pin_memory=True),
+                 ],
+                 tricksy_context.load_weight_stream,
+             )
+             torch.cuda.synchronize()
+             # Weights stored in shape for efficient indexing to support offloading attention heads (not currently being done)
+             self.qw = self.catted_weights[:, :self.head_dim, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
+             self.kw = self.catted_weights[:, self.head_dim:self.head_dim * 2, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
+             self.vw = self.catted_weights[:, self.head_dim * 2:self.head_dim * 3, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
+             self.ow = self.catted_weights[:, self.head_dim * 3:, :].reshape(self.config.hidden_size, self.config.hidden_size).t().contiguous()
+             self.catted_weights = None
+
+             self.qb = self.catted_biases[:, 0, :].reshape(self.config.hidden_size).contiguous()
+             self.kb = self.catted_biases[:, 1, :].reshape(self.config.hidden_size).contiguous()
+             self.vb = self.catted_biases[:, 2, :].reshape(self.config.hidden_size).contiguous()
+             self.catted_biases = None
+
+     def forward(self, hidden_states, **kwargs):
+         # Wait for attention weights to get to GPU
+         torch.cuda.synchronize()
+
+         # Predict MLP sparsity based on attention input
+         self.tricksy_context.indices.mlp_indices_buffer_gpu = topk_and_threshold(
+             self.inputs.sparsity_predictors[0](hidden_states)[0, -1, :],
+             int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu),
+         )
+         self.tricksy_context.indices.copy_mlp_indices_to_cpu()
+         torch.cuda.synchronize()
+
+         # Load MLP weights while computing attention
+         self.inputs.next_layer.load_weights(self.tricksy_context)
+
+         out = super().forward(self.layer_norm(hidden_states), **kwargs)
+
+         # Wait for MLP weights to get to GPU
+         torch.cuda.synchronize()
+
+         return out
+
+ class TricksyOPTDecoderLayer(OPTDecoderLayer):
+     def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext):
+         nn.Module.__init__(self)
+         self.tricksy_config = tricksy_config
+         self.config = tricksy_config.opt_config
+         self.embed_dim = self.config.hidden_size
+
+         self.tricksy_context = tricksy_context
+         self.self_attn_layer_inputs = TricksyLayerInputs(
+             disk_weights=inputs.disk_weights,
+             layer_key_prefix=inputs.layer_key_prefix,
+             # While computing attention, load MLP
+             next_layer=self,
+             sparsity_predictors=inputs.sparsity_predictors,
+         )
+         self.self_attn = TricksyOPTAttention(tricksy_config, self.self_attn_layer_inputs, tricksy_context, is_decoder=True)
+
+         self.do_layer_norm_before = self.config.do_layer_norm_before
+         self.dropout = self.config.dropout
+         self.activation_fn = ACT2FN[self.config.activation_function]
+
+         self.inputs = inputs
+         random_mlp_indices_gpu =\
+             torch.randperm(self.config.ffn_dim, device='cpu', dtype=torch.int32)[:int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu)]
+         self.index_cache = SparseMLPCache(gpu_cached_mlp_indices=random_mlp_indices_gpu)
+
+         # identity since we move this to attention layer
+         # extreme tricksy
+         self.self_attn_layer_norm = lambda x: x
+
+         self.fc1_weight = self.fc2_weight = self.final_layer_norm_weight = self.fc1_bias = self.fc2_bias = self.final_layer_norm_bias = None
+         self.ring_idx = 0
+         self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None
+         self.fc1 = lambda x: F.linear(x, torch.cat([self.fc1_weight, self.fc1_weight_diff]), torch.cat([self.fc1_bias, self.fc1_bias_diff]))
+         self.fc2 = lambda x: F.linear(x, torch.cat([self.fc2_weight, self.fc2_weight_diff]).T, self.fc2_bias)
+         self.final_layer_norm = lambda x: F.layer_norm(x, (self.embed_dim,), self.final_layer_norm_weight, self.final_layer_norm_bias)
+
+     def load_weights(self, tricksy_context: TricksyContext):
+         if self.tricksy_context.is_prompt_phase:
+             # Full weights for prompt phase
+             fc1w = mmap_to_tensor(self.inputs.get_weight('fc1.weight')[:], pin_memory=True)
+             fc1b = mmap_to_tensor(self.inputs.get_weight('fc1.bias')[:], pin_memory=True)
+             fc2w = mmap_to_tensor(self.inputs.get_weight('fc2.weight')[:], pin_memory=True)
+             fc2b = mmap_to_tensor(self.inputs.get_weight('fc2.bias')[:], pin_memory=True)
+             lnw = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True)
+             lnb = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True)
+
+             self.fc1_weight, self.fc1_bias, self.fc2_weight, self.fc2_bias, self.final_layer_norm_weight, self.final_layer_norm_bias =\
+                 batch_copy([fc1w, fc1b, fc2w, fc2b, lnw, lnb], tricksy_context.load_weight_stream)
+             self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+
+             index_diffs = compute_index_diffs(tricksy_context.indices.mlp_indices_buffer_cpu, [self.index_cache.gpu_cached_mlp_indices])
+             if len(index_diffs) > 0:
+                 gpu_index_diff = index_diffs[0]
+                 self.index_cache.gpu_cached_mlp_indices[gpu_index_diff.off_positions] = gpu_index_diff.off_elements
+
+             self.index_cache.indexed_fc1_weight = fc1w.contiguous().pin_memory()
+             self.index_cache.indexed_fc1_bias = fc1b.contiguous().pin_memory()
+             self.index_cache.indexed_fc2_weight = fc2w.contiguous().pin_memory()
+             return
+         elif self.fc1_weight is None:
+             # Full weights if full offload
+             self.fc1_weight, self.fc1_bias, self.fc2_weight = batch_copy(
+                 [self.index_cache.indexed_fc1_weight, self.index_cache.indexed_fc1_bias, self.index_cache.indexed_fc2_weight],
+                 tricksy_context.load_weight_stream
+             )
+             self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+
+         off_elements = torch.tensor(
+             list(set(tricksy_context.indices.mlp_indices_buffer_cpu.tolist()).difference(set(self.index_cache.gpu_cached_mlp_indices.tolist()))),
+             device='cpu',
+             dtype=torch.int32,
+             pin_memory=True
+         )
+         if off_elements.size(0) == 0:
+             self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
+             return
+
+         new_ring_idx = (self.ring_idx + off_elements.size(0)) % self.index_cache.gpu_cached_mlp_indices.size(0)
+         if new_ring_idx > self.ring_idx:
+             # single contiguous update
+             self.index_cache.gpu_cached_mlp_indices[self.ring_idx:new_ring_idx] = off_elements
+         elif off_elements.size(0) > 0:
+             split = self.index_cache.gpu_cached_mlp_indices.size(0) - self.ring_idx
+             # end of ring
+             self.index_cache.gpu_cached_mlp_indices[self.ring_idx:] = off_elements[:split]
+             # beginning of ring
+             self.index_cache.gpu_cached_mlp_indices[:new_ring_idx] = off_elements[split:]
+
+         # Allocate
+         self.fc1_weight_diff = torch.empty((off_elements.size(0), self.config.hidden_size), dtype=self.tricksy_config.dtype, device='cuda')
+         self.fc1_bias_diff = torch.empty((off_elements.size(0)), dtype=self.tricksy_config.dtype, device='cuda')
+         self.fc2_weight_diff = torch.empty((off_elements.size(0), self.config.hidden_size), dtype=self.tricksy_config.dtype, device='cuda')
+         # Index
+         fc1wd = self.index_cache.indexed_fc1_weight[off_elements].pin_memory()
+         fc1bd = self.index_cache.indexed_fc1_bias[off_elements].pin_memory()
+         fc2wd = self.index_cache.indexed_fc2_weight[off_elements].pin_memory()
+         # Copy
+         self.fc1_weight_diff, self.fc1_bias_diff, self.fc2_weight_diff = batch_copy([fc1wd, fc1bd, fc2wd], tricksy_context.load_weight_stream)
+
+     def forward(self, *args, **kwargs):
+         # Wait for attention weights to get to GPU
+         torch.cuda.synchronize()
+
+         # Load next layer's attention weights
+         self.inputs.next_layer.load_weights(self.tricksy_context)
+
+         out = super().forward(*args, **kwargs)
+
+         if self.tricksy_config.full_offload:
+             self.fc1_weight = self.fc1_bias = self.fc2_weight = None
+         elif self.tricksy_context.is_prompt_phase:
+             # Only keep sparse MLP weights on GPU after prompt phase
+             self.fc1_weight = self.fc1_weight[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
+             self.fc1_bias = self.fc1_bias[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
+             self.fc2_weight = self.fc2_weight[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
+
+         # Update ring buffers
+         if not self.tricksy_config.full_offload:
+             prev_ring_idx = self.ring_idx
+             self.ring_idx = (self.ring_idx + self.fc1_weight_diff.size(0)) % self.fc1_weight.size(0)
+             if self.ring_idx > prev_ring_idx:
+                 # does not wrap around ring
+                 self.fc1_weight[prev_ring_idx:self.ring_idx] = self.fc1_weight_diff
+                 self.fc1_bias[prev_ring_idx:self.ring_idx] = self.fc1_bias_diff
+                 self.fc2_weight[prev_ring_idx:self.ring_idx] = self.fc2_weight_diff
+             elif self.fc1_weight_diff.size(0) > 0:
+                 # wraps around ring
+                 split = self.fc1_weight_diff.size(0) - self.ring_idx
+                 self.fc1_weight[prev_ring_idx:] = self.fc1_weight_diff[:split]
+                 self.fc1_weight[:self.ring_idx] = self.fc1_weight_diff[split:]
+                 self.fc1_bias[prev_ring_idx:] = self.fc1_bias_diff[:split]
+                 self.fc1_bias[:self.ring_idx] = self.fc1_bias_diff[split:]
+                 self.fc2_weight[prev_ring_idx:] = self.fc2_weight_diff[:split]
+                 self.fc2_weight[:self.ring_idx] = self.fc2_weight_diff[split:]
+             self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None
+
+         self.tricksy_context.layer += 1
+         return out
+
+ class TricksyOPTDecoder(OPTDecoder, TricksyLayer):
+     def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
+         nn.Module.__init__(self)
+         self.config = tricksy_config.opt_config
+         self.dropout = self.config.dropout
+         self.layerdrop = self.config.layerdrop
+         self.padding_idx = self.config.pad_token_id
+         self.max_target_positions = self.config.max_position_embeddings
+         self.vocab_size = self.config.vocab_size
+         self._use_flash_attention_2 = False
+         self.gradient_checkpointing = False
+         self.project_out = None
+         self.project_in = None
+
+         self.embed_tokens_weight = None
+         self.embed_positions = TricksyOPTLearnedPositionalEmbedding(tricksy_context)
+
+         self.tricksy_context = tricksy_context
+         self.layers: List[TricksyOPTDecoderLayer] = []
+         for i in range(self.config.num_hidden_layers):
+             pretrained_layer_num = self.config.num_hidden_layers - i - 1
+             sparsity_predictors = [load_mlp_sparsity_predictor(disk_weights.weight_path, pretrained_layer_num, tricksy_config.dtype)]
+             if sparsity_predictors[0] is None:
+                 sparsity_predictors[0] = lambda x: F.linear(x, torch.rand((self.config.ffn_dim, self.config.hidden_size), device='cuda', dtype=tricksy_config.dtype))
+             self.layers.append(TricksyOPTDecoderLayer(
+                 tricksy_config,
+                 TricksyLayerInputs(
+                     disk_weights=disk_weights,
+                     layer_key_prefix=f'decoder.layers.{pretrained_layer_num}.',
+                     # While computing MLP, load next attention
+                     # While computing last MLP, load output embeddings (stored in TricksyOPTForCausalLM)
+                     next_layer=self.layers[i - 1].self_attn if i > 0 else tricksy_opt_for_causal_lm,
+                     sparsity_predictors=sparsity_predictors,
+                 ),
+                 tricksy_context,
+             ))
+         self.layers.reverse()
+
+         self.final_layer_norm = lambda x: x
+         self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.')
+
+     def embed_tokens(self, x):
+         return F.embedding(x, self.embed_tokens_weight, self.padding_idx)
+
+     def load_weights(self, tricksy_context: TricksyContext):
+         if self.embed_tokens_weight is None:
+             self.embed_tokens_weight, self.embed_positions.weight = batch_copy(
+                 [
+                     mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('embed_positions.weight')[:], pin_memory=True),
+                 ],
+                 tricksy_context.load_weight_stream,
+             )
+
+     def forward(self, *args, **kwargs):
+         # Wait for input embedding weights to get to GPU
+         torch.cuda.synchronize()
+
+         # While computing input embeddings, load first attention
+         self.layers[0].self_attn.load_weights(self.tricksy_context)
+
+         out = super().forward(*args, **kwargs)
+
+         # Wait for output embedding weights to get to GPU
+         torch.cuda.synchronize()
+
+         # No longer prompt phase after first full pass
+         self.tricksy_context.is_prompt_phase = False
+         # Load input embeddings while computing output
+         self.load_weights(self.tricksy_context)
+
+         return out
+
+ class TricksyOPTModel(OPTModel):
+     def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
+         nn.Module.__init__(self)
+         self.config = tricksy_config.opt_config
+         self.tricksy_context = tricksy_context
+         self.decoder = TricksyOPTDecoder(tricksy_config, disk_weights, tricksy_opt_for_causal_lm, tricksy_context)
+
+     def forward(self, *args, **kwargs):
+         out = super().forward(*args, **kwargs)
+         return out
+
+ # who's got the weights?
+ # [InputEmbedding, Attention.0, MLP.0, Attention.1, MLP.1, ..., OutputEmbedding]
+ # [TricksyOPTDecoder, TricksyOPTAttention.0, TricksyOPTDecoderLayer.0, TricksyOPTAttention.1, TricksyOPTDecoderLayer.1, ..., TricksyOPTForCausalLM]
+ #
+ # 1. Prompt pass: Before computing layer, send full dense weights to GPU. After computing layer, only keep sparse weights on GPU.
+ # 2. Generation passes: Before computing layer, compute and send sparse weight diff to GPU.
+ class TricksyOPTForCausalLM(OPTForCausalLM, TricksyLayer):
+     def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights):
+         nn.Module.__init__(self)
+         self.config = disk_weights.config
+         self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None
+
+         self.tricksy_context = TricksyContext(tricksy_config, self.config)
+         self.model = TricksyOPTModel(tricksy_config, disk_weights, self, self.tricksy_context)
+
+         self.final_layer_norm_weight = self.lm_head_weight = self.final_layer_norm_bias = None
+         # double stacking tricksy!
+         self.final_layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.final_layer_norm_weight, self.final_layer_norm_bias)
+         self.lm_head = lambda x: F.linear(self.final_layer_norm(x), self.lm_head_weight)
+
+         self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.', next_layer=self.model.decoder)
+
+     def load_weights(self, tricksy_context: TricksyContext):
+         if self.final_layer_norm_weight is None:
+             self.final_layer_norm_weight, self.lm_head_weight, self.final_layer_norm_bias = batch_copy(
+                 [
+                     mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
+                     mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True),
+                 ],
+                 tricksy_context.load_weight_stream,
+             )
+
+     def forward(self, *args, **kwargs):
+         torch.cuda.synchronize()
+         start = time.time()
+         out = super().forward(*args, **kwargs)
+         torch.cuda.synchronize()
+         self.tricksy_context.forward_times.append(time.time() - start)
+         self.tricksy_context.layer = 0
+         return out
+
+     def generate(self, *args, **kwargs):
+         # Load input embeddings for first token
+         self.model.decoder.load_weights(self.tricksy_context)
+         torch.cuda.synchronize()
+         out = super().generate(*args, **kwargs)
+         return out
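For orientation, the intended end-to-end flow of these classes outside the Streamlit app looks roughly like the sketch below; it assumes a CUDA device, the dependencies listed earlier, and enough disk space for the preprocessed memmap weights:

    from transformers import AutoTokenizer, set_seed
    from configuration_tricksy import TricksyConfig
    from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights

    set_seed(42)
    model_name = 'facebook/opt-6.7b'
    disk_weights = OPTDiskWeights(model_name)  # downloads or builds the memmap weight files on disk
    model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config), disk_weights)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    input_ids = tokenizer('Making pesto from scratch', return_tensors='pt').input_ids.to('cuda')
    output_ids = model.generate(input_ids, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.9)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))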
util.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ from typing import List, Callable
+
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+
+ np_dtype_to_torch_dtype = {
+     np.float16: torch.float16,
+     np.float32: torch.float32,
+     np.uint8: torch.uint8,
+     np.int8: torch.int8,
+     np.int32: torch.int32,
+     np.int64: torch.int64,
+     bool: torch.bool,
+ }
+
+ class IndexDiff:
+     def __init__(self, off_elements: torch.Tensor=None, off_positions: torch.Tensor=None, on_positions: torch.Tensor=None):
+         self.off_elements = off_elements
+         self.off_positions = off_positions
+         self.on_positions = on_positions
+
+ def batch_copy(sources: List[torch.Tensor], copy_stream, indices=None, device='cuda'):
+     with torch.cuda.stream(copy_stream):
+         out = ()
+         for src in sources:
+             indexed = src[indices] if indices is not None else src
+             dst = torch.empty(indexed.shape, device=device, dtype=src.dtype)
+             dst.copy_(indexed, non_blocking=True)
+             out += (dst,)
+     return out
+
+ def mmap_to_tensor(torch_wrapped_mmap, pin_memory=False) -> torch.Tensor:
+     out = torch.empty(torch_wrapped_mmap.shape, dtype=torch_wrapped_mmap.dtype, device='cpu', pin_memory=pin_memory)
+     out.copy_(torch_wrapped_mmap)
+     return out
+
+ # Assuming that each entry of cached_indices is a step down the memory hierarchy,
+ # compute the diff at each level of the hierarchy.
+ # e.g. the first loop computes the indices that the GPU does not have,
+ # and the second loop computes the indices *of that diff* that the CPU does not have.
+ def compute_index_diffs(new_indices: torch.Tensor, cached_indices_list: List[torch.Tensor], pin_memory=True):
+     diffs = []
+     current_diff = new_indices
+     for cached_indices in cached_indices_list:
+         if current_diff.size(0) == 0:
+             # No need to go further down the hierarchy
+             break
+
+         # Compute elements of new indices not contained in current indices
+         off_elements = torch.tensor(
+             list(set(current_diff.tolist()).difference(set(cached_indices.tolist()))),
+             device='cpu',
+             dtype=torch.int32,
+             pin_memory=pin_memory
+         )
+         # Mask of cached positions whose element is still wanted by the new indices; the complement gives evictable positions
+         on_position_mask = torch.isin(cached_indices, current_diff, assume_unique=True)
+         on_positions = torch.nonzero(on_position_mask).flatten()
+         off_positions = torch.nonzero(~on_position_mask).flatten()[:off_elements.size(0)]
+
+         diffs.append(IndexDiff(off_elements, off_positions, on_positions))
+         current_diff = off_elements
+     return diffs
+
+ def topk_and_threshold(x, k, threshold=1):
+     vals, indices = torch.topk(x, k, sorted=True)
+     return indices[vals > threshold].int()
+
+ def load_mlp_sparsity_predictor(weight_path_prefix: str, layer_num: int, dtype: torch.dtype, device: str = 'cuda') -> Callable:
+     path_prefix = f'{weight_path_prefix}decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.'
+     return load_predictor(path_prefix, dtype, device=device)
+
+ def load_predictor(path_prefix: str, dtype: torch.dtype, device: str='cuda') -> Callable:
+     path = lambda i: os.path.expanduser(f'{path_prefix}{i}.weight')
+     if os.path.exists(path(1)):
+         l1 = torch.load(path(1)).to(device).to(dtype)
+         l2 = torch.load(path(2)).to(device).to(dtype)
+         return lambda x: F.linear(F.linear(x, l1), l2)
+     else:
+         print(f'could not find predictor at {path(1)}')
+         return None
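To make the index bookkeeping concrete, here is a small CPU-only sketch (illustrative values only, not part of the commit) of how topk_and_threshold and compute_index_diffs work together:

    import torch
    from util import topk_and_threshold, compute_index_diffs

    # Pretend sparsity-predictor scores for an 8-neuron MLP; keep the top 4 above the default threshold.
    scores = torch.tensor([0.1, 5.0, 2.0, 0.2, 7.0, 3.0, 0.0, 4.0])
    new_indices = topk_and_threshold(scores, 4)    # e.g. tensor([4, 1, 7, 5], dtype=torch.int32)

    # Indices currently cached on the GPU (the cache itself lives as a CPU tensor).
    cached = torch.tensor([1, 2, 3, 4], dtype=torch.int32)

    diffs = compute_index_diffs(new_indices, [cached], pin_memory=False)
    print(diffs[0].off_elements)   # neurons the GPU is missing, e.g. tensor([5, 7], dtype=torch.int32)
    print(diffs[0].off_positions)  # cache slots that can be overwritten with those rows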