chansung committed
Commit 018c6d9 · 1 Parent(s): c32e1f6

Update llama/model.py

Files changed (1)
  1. llama/model.py  +114 -42
llama/model.py CHANGED
@@ -1,20 +1,18 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
-from typing import Optional, Tuple
+from contextvars import ContextVar
+
+from typing import Optional, Tuple, Type
 from dataclasses import dataclass
 import math
 
 import torch
 from torch import nn
 import torch.nn.functional as F
+import bitsandbytes as bnb
 
-import fairscale.nn.model_parallel.initialize as fs_init
-from fairscale.nn.model_parallel.layers import (
-    ParallelEmbedding,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
+import tqdm
 
 
 @dataclass
@@ -73,40 +71,57 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
+class UninitializedLinear(nn.Linear):
+    def reset_parameters(self) -> None:
+        pass
+
+
+class InferenceQuantizedLinear(bnb.nn.Linear8bitLt):
+    def __init__(self, *args, **kwargs):
+        super().__init__(has_fp16_weights=False, *args, **kwargs)
+
+    def reset_parameters(self) -> None:
+        pass
+
+
+default_quantize: ContextVar[bool] = ContextVar("default_quantize", default=False)
+
+
+def get_linear_class() -> Type[nn.Linear]:
+    if default_quantize.get():
+        return InferenceQuantizedLinear
+    return UninitializedLinear
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
 
-        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+        self.n_local_heads = (
+            args.n_heads // 1
+        )  # fs_init.get_model_parallel_world_size()
         self.head_dim = args.dim // args.n_heads
 
-        self.wq = ColumnParallelLinear(
+        Linear = get_linear_class()
+        self.wq = Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+        self.wk = Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wv = ColumnParallelLinear(
+        self.wv = Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wo = RowParallelLinear(
-            args.n_heads * self.head_dim,
+        self.wo = Linear(
             args.dim,
+            args.n_heads * self.head_dim,
             bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
         )
 
         self.cache_k = torch.zeros(
@@ -116,7 +131,13 @@ class Attention(nn.Module):
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
         ).cuda()
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
@@ -143,9 +164,7 @@ class Attention(nn.Module):
         scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-        output = output.transpose(
-            1, 2
-        ).contiguous().view(bsz, seqlen, -1)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
         return self.wo(output)
 
@@ -161,14 +180,17 @@ class FeedForward(nn.Module):
         hidden_dim = int(2 * hidden_dim / 3)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
-        )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        Linear = get_linear_class()
+        self.w1 = Linear(dim, hidden_dim, bias=False)
+        self.w2 = Linear(
+            hidden_dim,
+            dim,
+            bias=False,
         )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = Linear(
+            dim,
+            hidden_dim,
+            bias=False,
         )
 
     def forward(self, x):
@@ -189,12 +211,36 @@ class TransformerBlock(nn.Module):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
-        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        h = x + self.attention.forward(
+            self.attention_norm(x), start_pos, freqs_cis, mask
+        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
 
 
+def convert_linear_to_bnb(float_linear):
+    new_layer = InferenceQuantizedLinear(
+        float_linear.in_features,
+        float_linear.out_features,
+        bias=float_linear.bias is not None,
+    )
+    new_layer._parameters["weight"] = bnb.nn.Int8Params(
+        float_linear.weight.data.cpu(),
+        requires_grad=False,
+        has_fp16_weights=False,
+    )
+    if float_linear.bias is not None:
+        new_layer._parameters["bias"] = float_linear.bias
+    return new_layer
+
+
 class Transformer(nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
@@ -202,18 +248,16 @@ class Transformer(nn.Module):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
-        )
+        self.tok_embeddings = torch.nn.Embedding(params.vocab_size, params.dim)
 
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
-        )
+
+        Linear = get_linear_class()
+        self.output = Linear(params.dim, params.vocab_size, bias=False)
 
         self.freqs_cis = precompute_freqs_cis(
             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
@@ -228,11 +272,39 @@ class Transformer(nn.Module):
 
         mask = None
         if seqlen > 1:
-            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+            mask = torch.full(
+                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+            )
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
 
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)
         output = self.output(h[:, -1, :])  # only compute last logits
-        return output.float()
+        return output.float()
+
+    def quantize(self):
+        # https://github.com/pytorch/vision/issues/2391#issuecomment-653900218
+        def get_layer(model, name):
+            layer = model
+            for attr in name.split("."):
+                layer = getattr(layer, attr)
+            return layer
+
+        def set_layer(model, name, layer):
+            try:
+                attrs, name = name.rsplit(".", 1)
+                model = get_layer(model, attrs)
+            except ValueError:
+                pass
+            setattr(model, name, layer)
+
+        linear_layers = {
+            k: v for k, v in self.named_modules() if isinstance(v, nn.Linear)
+        }
+
+        print("Quantizing", len(linear_layers), "layers")
+        for name, layer in tqdm.tqdm(linear_layers.items()):
+            new_layer = convert_linear_to_bnb(layer)
+            set_layer(self, name, new_layer)
+        self.cuda()
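
Not part of the commit above: a minimal usage sketch of the build-time path, where the default_quantize ContextVar is flipped before the Transformer is constructed so that get_linear_class() returns InferenceQuantizedLinear for every projection. The ModelArgs values below are toy placeholders (real sizes come from a checkpoint's params.json), the import path assumes this repository's llama/model.py, and a CUDA device plus bitsandbytes are required because the attention caches and the 8-bit layers live on the GPU.

# Hypothetical sketch, not part of the commit: build the model with 8-bit
# linear layers from the start by setting the default_quantize ContextVar.
from llama.model import ModelArgs, Transformer, default_quantize

params = ModelArgs(
    dim=512,           # toy sizes for illustration only
    n_layers=2,
    n_heads=8,
    vocab_size=32000,
    max_batch_size=1,
    max_seq_len=128,
)

token = default_quantize.set(True)   # get_linear_class() now returns InferenceQuantizedLinear
try:
    model = Transformer(params)      # KV caches call .cuda(), so a GPU is needed here
finally:
    default_quantize.reset(token)    # restore the UninitializedLinear default

The try/finally around set()/reset() is ordinary ContextVar hygiene; weights for the 8-bit layers would then be loaded into model by whatever loader the caller uses.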
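
Also not part of the commit: a sketch of the post-hoc path that the new quantize() method enables. The model is built with the default UninitializedLinear layers, fp16 weights are loaded, and quantize() then walks named_modules(), swaps every nn.Linear for an InferenceQuantizedLinear via convert_linear_to_bnb, and calls .cuda(); the bnb.nn.Int8Params weights are quantized to int8 during that move to the GPU. The checkpoint path and loading lines are hypothetical placeholders.

# Hypothetical sketch, not part of the commit: load fp16 weights first, then
# convert the model's Linear layers to bitsandbytes 8-bit layers in place.
import torch
from llama.model import ModelArgs, Transformer

torch.set_default_dtype(torch.float16)   # build parameters in fp16 to match the checkpoint

model = Transformer(
    ModelArgs(dim=512, n_layers=2, n_heads=8, vocab_size=32000,
              max_batch_size=1, max_seq_len=128)   # toy sizes for illustration only
)

# checkpoint = torch.load("consolidated.00.pth", map_location="cpu")  # hypothetical path
# model.load_state_dict(checkpoint, strict=False)

model.quantize()   # replaces each nn.Linear with InferenceQuantizedLinear and calls .cuda();
                   # Int8Params quantizes the weights to int8 when they are moved to the GPU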