namanbnsl committed on
Commit
be6911a
1 Parent(s): 479d404

Update modelling_custom.py

Files changed (1)
  1. modelling_custom.py +351 -1
modelling_custom.py CHANGED
@@ -1,7 +1,357 @@
  import torch
  from transformers import PreTrainedModel
- from model import Moose
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import tiktoken
+
+ @dataclass
+ class ModelArgs:
+     dim: int = 768
+     n_layers: int = 16
+     n_heads: int = 16
+     n_kv_heads: Optional[int] = 4
+     vocab_size: int = 50304
+     multiple_of: int = 256
+     ffn_dim_multiplier: Optional[float] = None
+     norm_eps: float = 1e-5
+     rope_theta: float = 50000
+     max_batch_size: int = 4
+     max_seq_len: int = 1024
+     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+     dropout_rate: float = 0.1
+
+ params = ModelArgs()
+
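+ # Note: vocab_size=50304 is presumably GPT-2's 50257 padded up to a multiple
+ # of 64, a common trick for better tensor-core utilization; the tokenizer
+ # used in generate() below is tiktoken's 'gpt2' encoding.
+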
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         # y = x / sqrt(mean(x^2) + eps); no mean-centering, unlike LayerNorm
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis.to(params.device)
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != (x.shape[1], x.shape[-1]) {(x.shape[1], x.shape[-1])}'
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
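+ # A quick RoPE recap: precompute_freqs_cis stores e^(i * t * theta^(-2k/dim))
+ # for every position t and frequency index k, so apply_rotary_emb can rotate
+ # each (even, odd) channel pair of q and k with a single complex multiply.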
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+     bs, seqlen, n_kv_heads, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     return (
+         x[:, :, :, None, :]
+         .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
+         .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+     )
+
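+ # Shape sketch for grouped-query attention with the defaults above
+ # (n_heads=16, n_kv_heads=4, head_dim=768//16=48): repeat_kv expands keys and
+ # values from (bsz, seqlen, 4, 48) to (bsz, seqlen, 16, 48) so every group of
+ # 4 query heads shares one cached key/value head.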
+ class Attention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+         self.n_rep = args.n_heads // self.n_kv_heads
+         self.head_dim = args.dim // args.n_heads
+
+         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
+         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+
+         self.cache_k = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim),
+             requires_grad=False,
+         ).to(args.device)
+         self.cache_v = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim),
+             requires_grad=False,
+         ).to(args.device)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         mask: Optional[torch.Tensor],
+         start_pos: Optional[int] = None,
+     ):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+         xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         if start_pos is not None:  # if we're performing inference, use kv caching
+             self.cache_k = self.cache_k.to(xq)
+             self.cache_v = self.cache_v.to(xq)
+
+             # set the values in our cache according to the current input
+             self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+             self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+             # grab our key and value matrices, which have a longer sequence length than our queries
+             keys = self.cache_k[:bsz, : start_pos + seqlen]
+             values = self.cache_v[:bsz, : start_pos + seqlen]
+         else:
+             # if we're training, use the full sequence length
+             keys, values = xk, xv
+
+         keys = repeat_kv(keys, self.n_rep)
+         values = repeat_kv(values, self.n_rep)
+
+         xq = xq.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if mask is not None:
+             scores = scores + mask
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+
+         output = torch.matmul(scores, values)
+         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+         return self.wo(output)
+
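+ # FeedForward below is the SwiGLU variant used by Llama: w2(silu(w1(x)) * w3(x)).
+ # The 2/3 factor keeps the parameter count near a plain 4*dim MLP even though
+ # there are three projections instead of two.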
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+         ffn_dim_multiplier: Optional[float],
+     ):
+         super().__init__()
+         # custom dim factor: round hidden_dim up to a multiple of "multiple_of" for hardware efficiency
+         hidden_dim = int(2 * hidden_dim / 3)
+         if ffn_dim_multiplier is not None:
+             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+         self.attention = Attention(args)
+         self.feed_forward = FeedForward(
+             dim=args.dim,
+             hidden_dim=4 * args.dim,
+             multiple_of=args.multiple_of,
+             ffn_dim_multiplier=args.ffn_dim_multiplier,
+         )
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.dropout_rate = args.dropout_rate
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         mask: Optional[torch.Tensor],
+         start_pos: Optional[int] = None,
+         training: bool = False,
+     ):
+         # pre-norm residual connections, with dropout on each sublayer's output
+         h = x + F.dropout(self.attention(self.attention_norm(x), freqs_cis, mask, start_pos), p=self.dropout_rate, training=training)
+         out = h + F.dropout(self.feed_forward(self.ffn_norm(h)), p=self.dropout_rate, training=training)
+         return out
+
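+ # Moose assembles the pieces: token embedding -> n_layers pre-norm transformer
+ # blocks -> final RMSNorm -> untied linear head over the vocabulary.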
+ class Moose(nn.Module):
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+         self.params = params
+         self.vocab_size = params.vocab_size
+         self.n_layers = params.n_layers
+         self.max_seq_len = params.max_seq_len
+
+         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+
+         self.layers = torch.nn.ModuleList()
+         for _ in range(params.n_layers):
+             self.layers.append(TransformerBlock(params))
+
+         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+
+         self.freqs_cis = precompute_freqs_cis(
+             params.dim // params.n_heads,
+             params.max_seq_len * 2,
+             params.rope_theta,
+         )
+
+         # precompute the causal attention mask
+         mask = torch.full((params.max_seq_len, params.max_seq_len),
+                           float("-inf"),
+                           device=params.device)
+         mask = torch.triu(mask, diagonal=1)
+         self.register_buffer('mask', mask)
+
+         self.criterion = nn.CrossEntropyLoss()
+
+     def forward(self,
+                 tokens: torch.Tensor,
+                 targets: torch.Tensor):
+         bsz, seqlen = tokens.shape
+         assert tokens.shape == targets.shape
+         assert seqlen == self.max_seq_len
+
+         h = self.tok_embeddings(tokens)
+
+         freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = freqs_cis[:seqlen]  # slice the moved tensor, not self.freqs_cis
+
+         for layer in self.layers:
+             h = layer(
+                 h,
+                 freqs_cis,
+                 self.mask,
+                 start_pos=None,
+                 training=True,
+             )
+
+         h = self.norm(h)
+         logits = self.output(h).float()
+
+         loss = self.criterion(
+             logits.view(bsz * seqlen, self.vocab_size),
+             targets.reshape(bsz * seqlen))
+
+         return logits, loss
+
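+     # Decode-time path: the caller slices the prompt to the last
+     # max_context_window tokens and passes start_pos, so attention reads keys
+     # and values from the per-layer cache rather than recomputing them.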
+     @torch.inference_mode()
+     def forward_inference(self,
+                           tokens: torch.Tensor,
+                           start_pos: int,
+                           max_context_window: int,  # accepted from generate(); not used in this body
+                           ):
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         self.freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+         # causal mask over the new tokens, left-padded with zeros so queries
+         # can attend to all cached positions before start_pos
+         mask = self.mask[:seqlen, :seqlen]
+         mask = torch.hstack(
+             [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
+         ).type_as(h)
+
+         for layer in self.layers:
+             h = layer(
+                 h,
+                 freqs_cis,
+                 mask,
+                 start_pos=start_pos,
+             )
+         h = self.norm(h)
+         logits = self.output(h).float()
+         return logits
+
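+     # Top-p (nucleus) sampling: sort probabilities descending, keep the
+     # smallest prefix whose cumulative mass reaches top_p (the
+     # (probs_sum - probs_sort) > top_p test always keeps the top token),
+     # renormalize, and sample one token.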
+     @torch.inference_mode()
+     def Sampler(
+         self,
+         logits: torch.Tensor,
+         temperature: float,
+         top_p: float,
+     ) -> torch.Tensor:
+         logits = logits[:, -1, :]  # keep only the final position
+
+         logits.div_(temperature)
+
+         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+
+         probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+
+         probs_sum = torch.cumsum(probs_sort, dim=-1)
+         top_ps_mask = (probs_sum - probs_sort) > top_p
+         probs_sort = torch.where(top_ps_mask, 0, probs_sort)
+         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+         # scatter the filtered probabilities back into vocabulary order
+         probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
+         next_token_id = torch.multinomial(probs, num_samples=1)
+
+         return next_token_id
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         prompt: str,
+         max_gen_len: int = 100,
+         memory_saver_div: int = 1,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+     ) -> str:
+         assert (memory_saver_div & (memory_saver_div - 1)) == 0 and memory_saver_div > 0, f'memory_saver_div {memory_saver_div} must be a power of 2'
+         max_context_window = self.max_seq_len // memory_saver_div
+
+         enc = tiktoken.get_encoding('gpt2')
+         tokens = enc.encode(prompt)
+
+         tokens = torch.tensor(tokens, device=self.params.device)
+         tokens = tokens.unsqueeze(0) if len(tokens.shape) == 1 else tokens
+
+         start_pos = max(tokens.shape[1] - max_context_window, 0)
+         eot = enc._special_tokens['<|endoftext|>']  # end of text token
+
+         for _ in range(max_gen_len):  # was `while True`, which ignored max_gen_len
+             logits = self.forward_inference(
+                 tokens[:, -max_context_window:],
+                 start_pos=start_pos,
+                 max_context_window=max_context_window,
+             )
+
+             next_token = self.Sampler(
+                 logits=logits,
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+
+             tokens = torch.cat((tokens, next_token), dim=1)
+
+             if next_token.item() == eot:
+                 break
+
+             if tokens.shape[1] >= max_context_window:
+                 start_pos += 1
+
+         output = enc.decode(tokens.squeeze(0).tolist())
+
+         return output
+
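+ # Example usage (a sketch; assumes trained weights have already been loaded
+ # into the module, which this file does not do by itself):
+ #   model = Moose(params).to(params.device)
+ #   print(model.generate("Once upon a time", max_gen_len=50))
+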
  class MooseModel(PreTrainedModel):
      def __init__(self, config):
          super().__init__(config)