BucketOfFish committed
Commit
a420fe7
1 Parent(s): 41127ee

Just uploading entire rewritten codebase at once

Files changed (6)
  1. attention.py +522 -0
  2. config.json +30 -27
  3. configuration_phi.py +0 -56
  4. modeling_phi.py +0 -766
  5. phi2_configuration.py +60 -0
  6. phi2_model.py +166 -0
attention.py ADDED
@@ -0,0 +1,522 @@
1
+ from dataclasses import dataclass, field
2
+ from einops import rearrange, repeat
3
+ import math
4
+ import torch
5
+ from torch.amp.autocast_mode import autocast
6
+ import torch.nn as nn
7
+ from transformers.activations import ACT2FN
8
+ from typing import cast
9
+
10
+ # use the flash_attn kernels if the package is installed
11
+ try:
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
14
+ from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
15
+ from flash_attn.ops.fused_dense import FusedDense
16
+ except ImportError:
17
+ print("flash_attn not found, using default implementations")
18
+ pad_input = unpad_input = FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None
19
+
20
+
21
+ class RotaryEmbedding(nn.Module):
22
+ """Rotary positional embedding (RoPE) from Phi2.
23
+ See https://www.youtube.com/watch?v=C6rV8BsrrCc
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ d_rotary: int,
29
+ rotary_base: float = 10000.0,
30
+ initial_cos_sin_cache_len: int = 2048,
31
+ device: torch.device | None = None,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.d_rotary = d_rotary
35
+ self.rotary_base = rotary_base
36
+ self.device = device
37
+ self.dtype = torch.float32
38
+ self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)
39
+
40
+ def _update_cos_sin_cache(self, seqlen: int) -> None:
41
+ # only call this function when seqlen is larger than _max_seqlen
42
+ self._max_seqlen = seqlen
43
+
44
+ # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
45
+ m = torch.arange(
46
+ seqlen,
47
+ device=self.device,
48
+ dtype=self.dtype,
49
+ )
50
+ theta_i = 1.0 / (
51
+ self.rotary_base ** (
52
+ torch.arange(
53
+ start=0,
54
+ end=self.d_rotary,
55
+ step=2,
56
+ device=self.device,
57
+ dtype=self.dtype,
58
+ ) / self.d_rotary
59
+ )
60
+ )
61
+ # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
62
+ # TODO: does this matter if I'm disabling torch.autocast?
63
+ m_theta_i = torch.outer(m, theta_i)
64
+ self._cos_cached = torch.cos(m_theta_i).to(self.dtype)
65
+ self._sin_cached = torch.sin(m_theta_i).to(self.dtype)
66
+
67
+ # TODO: scale_base caching is labelled as not yet done in Phi2
68
+ """
69
+ if scale_base is not None:
70
+ scale = (
71
+ torch.arange(
72
+ start=0,
73
+ end=self.d_rotary,
74
+ step=2,
75
+ device=self.device,
76
+ dtype=torch.float32,
77
+ ) + 0.4 * self.d_rotary
78
+ ) / (1.4 * self.d_rotary)
79
+ power = (
80
+ torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2
81
+ ) / scale_base
82
+ scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1")
83
+ self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype)
84
+ self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype)
85
+ """
86
+
87
+ def _apply_rotary_emb_qkv(
88
+ self,
89
+ x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
90
+ cos: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
91
+ sin: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
92
+ ) -> torch.FloatTensor:
93
+ seqlen = x.shape[1]
94
+ x1, x2 = x.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head/2)
95
+ broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
96
+ c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
97
+ x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32
98
+ return cast(
99
+ torch.FloatTensor,
100
+ torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
101
+ )
102
+
103
+ def forward(
104
+ self,
105
+ x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
106
+ seqlen_offset: int = 0, # each sequence is shifted by this amount - used in inference with KV cache
107
+ ) -> torch.FloatTensor:
108
+ if (
109
+ not self._max_seqlen
110
+ or self._max_seqlen < x.shape[1] + seqlen_offset
111
+ or self._cos_cached.device != x.device
112
+ or self._cos_cached.dtype != x.dtype
113
+ or (self.training and self._cos_cached.is_inference())
114
+ ):
115
+ self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
116
+ return self._apply_rotary_emb_qkv(
117
+ x,
118
+ cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
119
+ cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]),
120
+ )
121
+
122
+
123
+ class SelfAttention(nn.Module):
124
+ """Self-attention layer, taken from Phi2 model."""
125
+
126
+ def __init__(
127
+ self,
128
+ qk_scale: float | None = None, # will use 1/sqrt(d) if set to None
129
+ attention_dropout: float = 0.0,
130
+ ) -> None:
131
+ super().__init__()
132
+ self.qk_scale = qk_scale
133
+ self.dropout = nn.Dropout(attention_dropout)
134
+
135
+ # autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow
136
+ @autocast("cpu", enabled=False)
137
+ @autocast("cuda", enabled=False)
138
+ def forward(
139
+ self,
140
+ qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head)
141
+ causal: bool = True,
142
+ key_padding_mask: torch.BoolTensor | None = None,
143
+ ) -> torch.FloatTensor:
144
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
145
+ q, k, v = qkv.unbind(dim=2)
146
+ q = q.to(torch.float32)
147
+ k = k.to(torch.float32)
148
+ qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])
149
+
150
+ scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)
151
+
152
+ if key_padding_mask is not None:
153
+ padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
154
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
155
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
156
+
157
+ if causal:
158
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
159
+ scores = scores + causal_mask.to(dtype=scores.dtype)
160
+
161
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
162
+ attention = self.dropout(attention)
163
+
164
+ output = torch.einsum("bhts,bshd->bthd", attention, v) # dim: (batch_size, seqlen, n_heads, d_head)
165
+ return cast(torch.FloatTensor, output)
166
+
167
+
168
+ class CrossAttention(nn.Module):
169
+ """Cross-attention layer, taken from Phi2 model."""
170
+
171
+ def __init__(
172
+ self,
173
+ qk_scale: float | None = None, # will use 1/sqrt(d) if set to None
174
+ attention_dropout: float = 0.0,
175
+ ) -> None:
176
+ super().__init__()
177
+ self.qk_scale = qk_scale
178
+ self.dropout = nn.Dropout(attention_dropout)
179
+
180
+ # autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow
181
+ @autocast("cpu", enabled=False)
182
+ @autocast("cuda", enabled=False)
183
+ def forward(
184
+ self,
185
+ q: torch.FloatTensor, # dim: (batch_size, seqlen_q, n_heads, d_head)
186
+ kv: torch.FloatTensor, # dim: (batch_size, seqlen_kv, 2, n_heads, d_head)
187
+ causal: bool = True,
188
+ key_padding_mask: torch.BoolTensor | None = None,
189
+ ) -> torch.FloatTensor:
190
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
191
+ seqlen_k = kv.shape[1]
192
+ if kv.shape[3] != q.shape[2]: # repeat kv n_heads dim to match q n_heads
193
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
194
+ k, v = kv.unbind(dim=2)
195
+ q = cast(torch.FloatTensor, q.to(torch.float32))
196
+ k = k.to(torch.float32)
197
+ qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])
198
+
199
+ scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)
200
+
201
+ if key_padding_mask is not None:
202
+ padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
203
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
204
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
205
+
206
+ if causal:
207
+ rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
208
+ cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
209
+ causal_mask = cols > rows + seqlen_k - seqlen_q
210
+ scores = scores.masked_fill(causal_mask, -10000.0)
211
+
212
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
213
+ attention = self.dropout(attention)
214
+
215
+ output = torch.einsum("bhts,bshd->bthd", attention, v) # dim: (batch_size, seqlen_q, n_heads, d_head)
216
+ return cast(torch.FloatTensor, output)
217
+
218
+
219
+ class MLP(nn.Module):
220
+ """Taken from Phi2 as well."""
221
+
222
+ def __init__(
223
+ self,
224
+ d_embedding: int,
225
+ act_fn: str = "gelu_new",
226
+ ) -> None:
227
+ super().__init__()
228
+ n_inner = 4 * d_embedding
229
+ self.fc1 = nn.Linear(d_embedding, n_inner)
230
+ self.act = ACT2FN[act_fn]
231
+ self.fc2 = nn.Linear(n_inner, d_embedding)
232
+
233
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
234
+ x = self.fc1(x)
235
+ x = self.act(x)
236
+ x = self.fc2(x)
237
+ return x
238
+
239
+
240
+ @dataclass
241
+ class KVCache:
242
+ """Options for model to calculate and store context during inference."""
243
+ max_seqlen: int
244
+ max_batch_size: int
245
+ seqlen_offset: int
246
+ batch_size_offset: int
247
+ kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict)
248
+ lengths_per_sample: torch.Tensor | None = None
249
+
250
+
251
+ class MHA(nn.Module):
252
+ """Multi-head attention block."""
253
+
254
+ def __init__(
255
+ self,
256
+ d_embedding: int,
257
+ n_attn_heads: int,
258
+ block_n: int,
259
+ initial_cos_sin_cache_len: int, # length of cache for rotary embedding
260
+ attn_pdrop: float,
261
+ use_flash_rotary: bool, # use flash rotary embedding if possible
262
+ use_flash_attn: bool, # use flash attention if possible
263
+ use_fused_dense: bool, # use fused dense layer if possible
264
+ checkpointing: bool, # torch.utils.checkpoint
265
+ ) -> None:
266
+ super().__init__()
267
+
268
+ # rotary embedding
269
+ rotary_cls = (
270
+ FlashRotaryEmbedding
271
+ if use_flash_rotary and FlashRotaryEmbedding is not None
272
+ else RotaryEmbedding
273
+ )
274
+ self.rotary_emb = rotary_cls(
275
+ d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
276
+ initial_cos_sin_cache_len=initial_cos_sin_cache_len,
277
+ )
278
+
279
+ # self attention
280
+ self_attn_cls = (
281
+ FlashSelfAttention
282
+ if use_flash_attn and FlashSelfAttention is not None
283
+ else SelfAttention
284
+ )
285
+ self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)
286
+
287
+ # cross attention
288
+ cross_attn_cls = (
289
+ FlashCrossAttention
290
+ if use_flash_attn and FlashCrossAttention is not None
291
+ else CrossAttention
292
+ )
293
+ self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)
294
+
295
+ # MLP
296
+ self.n_attn_heads = n_attn_heads
297
+ self.d_head = d_embedding // n_attn_heads
298
+ linear_cls = (
299
+ FusedDense
300
+ if use_fused_dense and FusedDense is not None
301
+ else nn.Linear
302
+ )
303
+ self.Wqkv = linear_cls(
304
+ d_embedding,
305
+ self.d_head * (3 * self.n_attn_heads), # calculating q, k, v for all heads in block simultaneously
306
+ )
307
+ self.fc_out = linear_cls(d_embedding, d_embedding)
308
+
309
+ # settings
310
+ self.using_flash_attn = self_attn_cls is FlashSelfAttention
311
+ self.block_n = block_n
312
+ self.checkpointing = checkpointing
313
+
314
+ def _forward_self_attn(
315
+ self,
316
+ qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head)
317
+ key_padding_mask: torch.BoolTensor | None,
318
+ ) -> torch.FloatTensor:
319
+ qkv = cast(
320
+ torch.FloatTensor,
321
+ torch.cat(
322
+ [
323
+ self.rotary_emb(qkv[:, :, :2, :, :]), # qk
324
+ qkv[:, :, 2:, :, :], # v (keep the qkv dim so torch.cat along dim=2 works)
325
+ ],
326
+ dim=2,
327
+ )
328
+ )
329
+
330
+ if self.using_flash_attn and unpad_input and pad_input: # not touching flash attention code
331
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
332
+ cu_seqlens, max_seqlen, indices = None, None, None
333
+
334
+ # unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
335
+ if key_padding_mask is not None:
336
+ qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
337
+
338
+ if self.checkpointing:
339
+ attn_output = torch.utils.checkpoint.checkpoint(
340
+ self.inner_self_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
341
+ )
342
+ else:
343
+ attn_output = self.inner_self_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
344
+
345
+ # repad output
346
+ if key_padding_mask is not None:
347
+ return pad_input(attn_output, indices, batch_size, seqlen)
348
+ else:
349
+ return attn_output
350
+
351
+ if self.checkpointing:
352
+ return torch.utils.checkpoint.checkpoint(self.inner_self_attn, qkv, key_padding_mask=key_padding_mask)
353
+ else:
354
+ return self.inner_self_attn(qkv, key_padding_mask=key_padding_mask)
355
+
356
+ def _update_kv_cache(
357
+ self,
358
+ kv: torch.FloatTensor, # dim: (batch_size, seqlen, 2, n_heads, d_head)
359
+ kv_cache: KVCache,
360
+ block_n: int,
361
+ ) -> torch.FloatTensor:
362
+ if block_n not in kv_cache.kv_block_map:
363
+ kv_cache.kv_block_map[block_n] = torch.empty(
364
+ kv_cache.max_batch_size,
365
+ kv_cache.max_seqlen,
366
+ 2,
367
+ kv.shape[-2], # n_heads
368
+ kv.shape[-1], # d_head
369
+ dtype=kv.dtype,
370
+ device=kv.device,
371
+ )
372
+ kv_cache.kv_block_map[block_n][
373
+ kv_cache.batch_size_offset: kv_cache.batch_size_offset + kv.shape[0],
374
+ kv_cache.seqlen_offset: kv_cache.seqlen_offset + kv.shape[1],
375
+ ...
376
+ ] = kv
+ # return the accumulated cache (past and new tokens) so attention sees the full context
+ return cast(torch.FloatTensor, kv_cache.kv_block_map[block_n][
+ kv_cache.batch_size_offset: kv_cache.batch_size_offset + kv.shape[0],
+ :kv_cache.seqlen_offset + kv.shape[1],
+ ...,
+ ])
377
+
378
+ def _forward_cross_attn(
379
+ self,
380
+ qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head)
381
+ kv_cache: KVCache,
382
+ key_padding_mask: torch.BoolTensor | None,
383
+ ) -> torch.FloatTensor:
384
+ q = qkv[:, :, 0, :, :]
385
+ q = self.rotary_emb(
386
+ q,
387
+ seqlen_offset = 0 if kv_cache is None else kv_cache.seqlen_offset,
388
+ )
389
+ kv = cast(torch.FloatTensor, qkv[:, :, 1:, :, :])
390
+ kv = self._update_kv_cache(kv, kv_cache, self.block_n)
391
+ causal = kv_cache.seqlen_offset == 0 # causal mask for the initial (prompt) pass; off when decoding one token against the cache
392
+
393
+ if self.using_flash_attn and unpad_input and pad_input: # not touching flash attention code
394
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
395
+ seqlen_k = kv.shape[1]
396
+ cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, indices_q = (
397
+ None,
398
+ None,
399
+ None,
400
+ None,
401
+ None,
402
+ )
403
+
404
+ # unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
405
+ if key_padding_mask is not None:
406
+ kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
407
+
408
+ if seqlen_q == 1:
409
+ key_padding_mask = cast(torch.BoolTensor, torch.ones(batch_size, 1, dtype=torch.bool, device=q.device))
410
+ elif seqlen_q != seqlen_k:
411
+ key_padding_mask = cast(torch.BoolTensor, key_padding_mask[:, -seqlen_q:])
412
+
413
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
414
+
415
+ if self.checkpointing:
416
+ attn_output = torch.utils.checkpoint.checkpoint(
417
+ self.inner_cross_attn,
418
+ q,
419
+ kv,
420
+ causal=causal,
421
+ cu_seqlens=cu_seqlens_q,
422
+ max_seqlen=max_seqlen_q,
423
+ cu_seqlens_k=cu_seqlens_k,
424
+ max_seqlen_k=max_seqlen_k,
425
+ )
426
+ else:
427
+ attn_output = self.inner_cross_attn(
428
+ q,
429
+ kv,
430
+ causal=causal,
431
+ cu_seqlens=cu_seqlens_q,
432
+ max_seqlen=max_seqlen_q,
433
+ cu_seqlens_k=cu_seqlens_k,
434
+ max_seqlen_k=max_seqlen_k,
435
+ )
436
+
437
+ if key_padding_mask is not None:
438
+ return pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
439
+ else:
440
+ return attn_output
441
+
442
+ if self.checkpointing:
443
+ return torch.utils.checkpoint.checkpoint(
444
+ self.inner_cross_attn,
445
+ q,
446
+ kv,
447
+ key_padding_mask=key_padding_mask,
448
+ causal=causal,
449
+ )
450
+ else:
451
+ return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
452
+
453
+ def forward(
454
+ self,
455
+ x: torch.FloatTensor, # dim: (batch_size, seqlen, d_embedding)
456
+ kv_cache: KVCache | None = None,
457
+ key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
458
+ ) -> torch.FloatTensor:
459
+ if key_padding_mask is not None:
460
+ key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool()) # make sure it's bool and not int
461
+
462
+ qkv = self.Wqkv(x) # dim: (batch_size, seqlen, 3*n_heads*d_head)
463
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head) # dim: (batch_size, seqlen, 3, n_heads, d_head)
464
+ if kv_cache is None:
465
+ attn_output = self._forward_self_attn(qkv, key_padding_mask)
466
+ else:
467
+ attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)
468
+
469
+ output = rearrange(attn_output, "... h d -> ... (h d)")
470
+ output = self.fc_out(output)
471
+ return output
472
+
473
+
474
+ class ParallelAttentionBlock(nn.Module):
475
+ """From Phi2. Calculates attention and MLP in parallel. See 'Simplifying Transformer Blocks', Fig. 1 'Parallel'."""
476
+
477
+ def __init__(
478
+ self,
479
+ resid_pdrop: float, # a bit of a misnomer, right?
480
+ layer_norm_epsilon: float,
481
+ d_embedding: int,
482
+ n_attn_heads: int,
483
+ block_n: int,
484
+ initial_cos_sin_cache_len: int, # length of cache for rotary embedding
485
+ attn_pdrop: float,
486
+ use_flash_rotary: bool = True, # use flash rotary embedding if possible
487
+ use_flash_attn: bool = True, # use flash attention if possible
488
+ use_fused_dense: bool = True, # use fused dense layer if possible
489
+ checkpointing: bool = False, # torch.utils.checkpoint
490
+ ) -> None:
491
+ super().__init__()
492
+ self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
493
+ self.block_n = block_n
494
+ self.multi_head_attention = MHA(
495
+ d_embedding=d_embedding,
496
+ n_attn_heads=n_attn_heads,
497
+ block_n=block_n,
498
+ initial_cos_sin_cache_len=initial_cos_sin_cache_len,
499
+ attn_pdrop=attn_pdrop,
500
+ use_flash_rotary=use_flash_rotary,
501
+ use_flash_attn=use_flash_attn,
502
+ use_fused_dense=use_fused_dense,
503
+ checkpointing=checkpointing,
504
+ )
505
+ self.mlp = MLP(d_embedding)
506
+ self.dropout = nn.Dropout(resid_pdrop)
507
+
508
+ def forward(
509
+ self,
510
+ x: torch.FloatTensor, # dim: (batch_size, seq_len, d_embedding)
511
+ kv_cache: KVCache | None = None,
512
+ key_padding_mask: torch.BoolTensor | None = None,
513
+ ) -> torch.FloatTensor:
514
+ residual = x
515
+ x = self.layer_norm(x) # each token (dim: d_embedding) is normalized individually
516
+ attn_outputs = self.multi_head_attention(
517
+ x,
518
+ kv_cache=kv_cache,
519
+ key_padding_mask=key_padding_mask,
520
+ )
521
+ mlp_outputs = self.mlp(x)
522
+ return self.dropout(attn_outputs + mlp_outputs) + residual
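A minimal usage sketch (not part of the commit) for the block above: it wires one ParallelAttentionBlock with the pure-PyTorch fallbacks (all flash options off) and the embedding/head sizes that config.json declares. The import assumes attention.py is on the Python path as a top-level module.

import torch
from attention import ParallelAttentionBlock  # assumes this file is importable as `attention`

block = ParallelAttentionBlock(
    resid_pdrop=0.1,
    layer_norm_epsilon=1e-5,
    d_embedding=2560,
    n_attn_heads=32,
    block_n=0,
    initial_cos_sin_cache_len=2048,
    attn_pdrop=0.0,
    use_flash_rotary=False,  # fall back to the RotaryEmbedding class above
    use_flash_attn=False,    # fall back to SelfAttention / CrossAttention above
    use_fused_dense=False,   # fall back to nn.Linear
)
x = torch.randn(2, 16, 2560)  # (batch_size, seqlen, d_embedding)
y = block(x)                  # kv_cache=None, so the self-attention path is used
print(y.shape)                # torch.Size([2, 16, 2560])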
config.json CHANGED
@@ -1,29 +1,32 @@
1
  {
2
- "_name_or_path": "microsoft/phi-2",
3
- "activation_function": "gelu_new",
4
- "architectures": [
5
- "PhiForCausalLM"
6
- ],
7
- "attn_pdrop": 0.0,
8
- "auto_map": {
9
- "AutoConfig": "configuration_phi.PhiConfig",
10
- "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM"
11
- },
12
- "embd_pdrop": 0.0,
13
- "img_processor": null,
14
- "initializer_range": 0.02,
15
- "layer_norm_epsilon": 1e-05,
16
- "model_type": "phi-msft",
17
- "n_embd": 2560,
18
- "n_head": 32,
19
- "n_head_kv": null,
20
- "n_inner": null,
21
- "n_layer": 32,
22
- "n_positions": 2048,
23
- "resid_pdrop": 0.1,
24
- "rotary_dim": 32,
25
- "tie_word_embeddings": false,
26
- "torch_dtype": "float16",
27
- "transformers_version": "4.35.2",
28
- "vocab_size": 51200
 
 
 
29
  }
 
1
  {
2
+ "_name_or_path": "BucketOfFish/simplified_phi2",
3
+ "architectures": [
4
+ "Phi2Model",
5
+ "Phi2ModelForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "phi2_configuration.Phi2Config",
9
+ "AutoModel": "phi2_model.Phi2Model",
10
+ "AutoModelForCausalLM": "phi2_model.Phi2ModelForCausalLM"
11
+ },
12
+ "model_type": "phi2",
13
+ "torch_dtype": "float16",
14
+ "transformers_version": "4.29.0",
15
+
16
+ "vocab_size": 50304,
17
+ "vocab_chunk_for_gpu_efficiency": 64,
18
+ "initial_cos_sin_cache_len": 2048,
19
+ "d_embedding": 2560,
20
+ "n_blocks": 32,
21
+ "n_heads": 32,
22
+ "use_flash_attn": false,
23
+ "use_flash_rotary": false,
24
+ "use_fused_dense": false,
25
+ "attn_pdrop": 0.0,
26
+ "embd_pdrop": 0.0,
27
+ "resid_pdrop": 0.1,
28
+ "layer_norm_epsilon": 1e-05,
29
+ "weight_initialization_range": 0.02,
30
+ "tie_word_embeddings": false,
31
+ "checkpointing": false
32
  }
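Since the new config keeps an auto_map pointing at phi2_configuration.Phi2Config and phi2_model.Phi2ModelForCausalLM, the usual way to consume this repo is through the Auto classes with remote code enabled. A hedged sketch (not part of the commit; assumes the Hub repo BucketOfFish/simplified_phi2 carries weights that match these class names):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("BucketOfFish/simplified_phi2", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "BucketOfFish/simplified_phi2",
    trust_remote_code=True,  # lets transformers import phi2_configuration.py / phi2_model.py from the repo
    torch_dtype="auto",      # config.json declares float16
)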
configuration_phi.py DELETED
@@ -1,56 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- import math
5
- from typing import Optional
6
-
7
- from transformers import PretrainedConfig
8
-
9
-
10
- class PhiConfig(PretrainedConfig):
11
- """Phi configuration."""
12
-
13
- model_type = "phi-msft"
14
- attribute_map = {
15
- "max_position_embeddings": "n_positions",
16
- "hidden_size": "n_embd",
17
- "num_attention_heads": "n_head",
18
- "num_hidden_layers": "n_layer",
19
- }
20
-
21
- def __init__(
22
- self,
23
- vocab_size: int = 50304,
24
- n_positions: int = 2048,
25
- n_embd: int = 1024,
26
- n_layer: int = 20,
27
- n_inner: Optional[int] = None,
28
- n_head: int = 16,
29
- n_head_kv: Optional[int] = None,
30
- rotary_dim: Optional[int] = 32,
31
- activation_function: Optional[str] = "gelu_new",
32
- attn_pdrop: float = 0.0,
33
- embd_pdrop: float = 0.0,
34
- resid_pdrop: float = 0.0,
35
- layer_norm_epsilon: float = 1e-5,
36
- initializer_range: float = 0.02,
37
- tie_word_embeddings: bool = False,
38
- pad_vocab_size_multiple: int = 64,
39
- **kwargs
40
- ) -> None:
41
- self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
42
- self.n_positions = n_positions
43
- self.n_embd = n_embd
44
- self.n_layer = n_layer
45
- self.n_inner = n_inner
46
- self.n_head = n_head
47
- self.n_head_kv = n_head_kv
48
- self.rotary_dim = min(rotary_dim, n_embd // n_head)
49
- self.activation_function = activation_function
50
- self.attn_pdrop = attn_pdrop
51
- self.embd_pdrop = embd_pdrop
52
- self.resid_pdrop = resid_pdrop
53
- self.layer_norm_epsilon = layer_norm_epsilon
54
- self.initializer_range = initializer_range
55
-
56
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
modeling_phi.py DELETED
@@ -1,766 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
- #
4
- # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
5
- # Licensed under the BSD 3-Clause License.
6
-
7
- from __future__ import annotations
8
-
9
- import math
10
- from dataclasses import dataclass, field
11
- from typing import Any, Dict, Optional, Tuple, Union
12
-
13
- import torch
14
- import torch.nn as nn
15
- from einops import rearrange, repeat
16
- from transformers import PretrainedConfig, PreTrainedModel
17
- from transformers.activations import ACT2FN
18
- from transformers.modeling_outputs import CausalLMOutputWithPast
19
-
20
- from .configuration_phi import PhiConfig
21
-
22
-
23
- @dataclass
24
- class InferenceParams:
25
- """Inference parameters passed to model to efficiently calculate
26
- and store context during inference.
27
-
28
- Reference:
29
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
30
-
31
- Args:
32
- max_seqlen: Maximum sequence length.
33
- max_batch_size: Maximum batch size.
34
- seqlen_offset: Sequence length offset.
35
- batch_size_offset: Batch size offset.
36
- key_value_memory_dict: Key value memory dictionary.
37
- lengths_per_sample: Lengths per sample.
38
-
39
- """
40
-
41
- max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
42
-
43
- max_batch_size: int = field(metadata={"help": "Maximum batch size."})
44
-
45
- seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
46
-
47
- batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
48
-
49
- key_value_memory_dict: Dict[str, Any] = field(
50
- default_factory=dict, metadata={"help": "Key value memory dictionary."}
51
- )
52
-
53
- lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
54
-
55
-
56
- class Embedding(nn.Module):
57
- """Token embedding with dropout."""
58
-
59
- def __init__(self, config: PretrainedConfig) -> None:
60
- super().__init__()
61
-
62
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
63
- self.drop = nn.Dropout(config.embd_pdrop)
64
-
65
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
66
- input_shape = input_ids.size()
67
- input_ids = input_ids.view(-1, input_shape[-1])
68
-
69
- hidden_states = self.wte(input_ids)
70
- hidden_states = self.drop(hidden_states)
71
-
72
- return hidden_states
73
-
74
-
75
- class RotaryEmbedding(nn.Module):
76
- """Rotary positional embedding (RoPE) from Phi2.
77
- See https://www.youtube.com/watch?v=C6rV8BsrrCc
78
- """
79
-
80
- def __init__(
81
- self,
82
- d_rotary: int,
83
- rotary_base: float = 10000.0,
84
- initial_cos_sin_cache_len: int = 2048,
85
- device: torch.device | None = None,
86
- ) -> None:
87
- super().__init__()
88
- self.d_rotary = d_rotary
89
- self.rotary_base = rotary_base
90
- self.device = device
91
- self.dtype = torch.float32
92
- self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)
93
-
94
- def _update_cos_sin_cache(self, seqlen: int) -> None:
95
- # only call this function when seqlen is larger than _max_seqlen
96
- self._max_seqlen = seqlen
97
-
98
- # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
99
- m = torch.arange(
100
- seqlen,
101
- device=self.device,
102
- dtype=self.dtype,
103
- )
104
- theta_i = 1.0 / (
105
- self.rotary_base ** (
106
- torch.arange(
107
- start=0,
108
- end=self.d_rotary,
109
- step=2,
110
- device=self.device,
111
- dtype=self.dtype,
112
- ) / self.d_rotary
113
- )
114
- )
115
- # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
116
- # TODO: does this matter if I'm disabling torch.autocast?
117
- m_theta_i = torch.outer(m, theta_i)
118
- self._cos_cached = torch.cos(m_theta_i).to(self.dtype)
119
- self._sin_cached = torch.sin(m_theta_i).to(self.dtype)
120
-
121
- # TODO: scale_base caching is labelled as not yet done in Phi2
122
- """
123
- if scale_base is not None:
124
- scale = (
125
- torch.arange(
126
- start=0,
127
- end=self.d_rotary,
128
- step=2,
129
- device=self.device,
130
- dtype=torch.float32,
131
- ) + 0.4 * self.d_rotary
132
- ) / (1.4 * self.d_rotary)
133
- power = (
134
- torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2
135
- ) / scale_base
136
- scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1")
137
- self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype)
138
- self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype)
139
- """
140
-
141
- def _apply_rotary_emb_qkv(
142
- self,
143
- x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
144
- cos: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
145
- sin: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
146
- ) -> torch.FloatTensor:
147
- seqlen = x.shape[1]
148
- x1, x2 = x.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head/2)
149
- broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
150
- c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
151
- x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32
152
- return cast(
153
- torch.FloatTensor,
154
- torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
155
- )
156
-
157
- def forward(
158
- self,
159
- x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
160
- seqlen_offset: int = 0, # each sequence is shifted by this amount - used in inference with KV cache
161
- ) -> torch.FloatTensor:
162
- if (
163
- not self._max_seqlen
164
- or self._max_seqlen < x.shape[1] + seqlen_offset
165
- or self._cos_cached.device != x.device
166
- or self._cos_cached.dtype != x.dtype
167
- or (self.training and self._cos_cached.is_inference())
168
- ):
169
- self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
170
- return self._apply_rotary_emb_qkv(
171
- x,
172
- cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
173
- cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]),
174
- )
175
-
176
-
177
- class MLP(nn.Module):
178
- """Multi-Layer Perceptron.
179
-
180
- Reference:
181
- Attention Is All You Need.
182
- https://arxiv.org/pdf/1706.03762.pdf.
183
-
184
- """
185
-
186
- def __init__(
187
- self,
188
- config: PretrainedConfig,
189
- n_inner: Optional[int] = None,
190
- act_fn: Optional[str] = None,
191
- ) -> None:
192
- super().__init__()
193
-
194
- act_fn = config.activation_function if act_fn is None else act_fn
195
-
196
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
197
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
198
-
199
- self.fc1 = nn.Linear(config.n_embd, n_inner)
200
- self.fc2 = nn.Linear(n_inner, config.n_embd)
201
- self.act = ACT2FN[act_fn]
202
-
203
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
204
- hidden_states = self.fc1(hidden_states)
205
- hidden_states = self.act(hidden_states)
206
- hidden_states = self.fc2(hidden_states)
207
-
208
- return hidden_states
209
-
210
-
211
- class SelfAttention(nn.Module):
212
- """Self-attention layer (compatible with PyTorch).
213
-
214
- Reference:
215
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
216
-
217
- """
218
-
219
- def __init__(
220
- self,
221
- causal: bool = True,
222
- softmax_scale: Optional[float] = None,
223
- attention_dropout: float = 0.0,
224
- ) -> None:
225
- super().__init__()
226
-
227
- self.causal = causal
228
- self.softmax_scale = softmax_scale
229
- self.drop = nn.Dropout(attention_dropout)
230
-
231
- @torch.autocast("cpu", enabled=False)
232
- @torch.autocast("cuda", enabled=False)
233
- def forward(
234
- self,
235
- qkv: torch.FloatTensor,
236
- causal: bool = None,
237
- key_padding_mask: Optional[torch.BoolTensor] = None,
238
- **kwargs,
239
- ) -> torch.FloatTensor:
240
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
241
- q, k, v = qkv.unbind(dim=2)
242
-
243
- q = q.to(torch.float32)
244
- k = k.to(torch.float32)
245
-
246
- causal = self.causal if causal is None else causal
247
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
248
-
249
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
250
- # using float16, which might lead to overflow
251
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
252
-
253
- if key_padding_mask is not None:
254
- padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
255
- padding_mask.masked_fill_(key_padding_mask, 0.0)
256
-
257
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
258
-
259
- if causal:
260
- causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
261
- scores = scores + causal_mask.to(dtype=scores.dtype)
262
-
263
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
264
- attention = self.drop(attention)
265
-
266
- output = torch.einsum("bhts,bshd->bthd", attention, v)
267
-
268
- return output
269
-
270
-
271
- class CrossAttention(nn.Module):
272
- """Cross-attention layer (compatible with PyTorch).
273
-
274
- Reference:
275
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
276
-
277
- """
278
-
279
- def __init__(
280
- self,
281
- causal: bool = True,
282
- softmax_scale: Optional[float] = None,
283
- attention_dropout: float = 0.0,
284
- ) -> None:
285
- super().__init__()
286
-
287
- self.causal = causal
288
- self.softmax_scale = softmax_scale
289
- self.drop = nn.Dropout(attention_dropout)
290
-
291
- @torch.autocast("cpu", enabled=False)
292
- @torch.autocast("cuda", enabled=False)
293
- def forward(
294
- self,
295
- q: torch.FloatTensor,
296
- kv: torch.FloatTensor,
297
- causal: bool = None,
298
- key_padding_mask: Optional[torch.BoolTensor] = None,
299
- **kwargs,
300
- ) -> torch.FloatTensor:
301
- batch_size, seqlen_q = q.shape[0], q.shape[1]
302
- seqlen_k = kv.shape[1]
303
-
304
- if kv.shape[3] != q.shape[2]:
305
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
306
- k, v = kv.unbind(dim=2)
307
-
308
- q = q.to(torch.float32)
309
- k = k.to(torch.float32)
310
-
311
- causal = self.causal if causal is None else causal
312
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
313
-
314
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
315
- # using float16, which might lead to overflow
316
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
317
-
318
- if key_padding_mask is not None:
319
- padding_mask = torch.full(
320
- (batch_size, seqlen_k),
321
- -10000.0,
322
- dtype=scores.dtype,
323
- device=scores.device,
324
- )
325
- padding_mask.masked_fill_(key_padding_mask, 0.0)
326
-
327
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
328
-
329
- if causal:
330
- rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
331
- cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
332
- causal_mask = cols > rows + seqlen_k - seqlen_q
333
-
334
- scores = scores.masked_fill(causal_mask, -10000.0)
335
-
336
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
337
- attention = self.drop(attention)
338
-
339
- output = torch.einsum("bhts,bshd->bthd", attention, v)
340
-
341
- return output
342
-
343
-
344
- def _find_mha_dims(
345
- config: PretrainedConfig,
346
- n_head: Optional[int] = None,
347
- n_head_kv: Optional[int] = None,
348
- head_dim: Optional[int] = None,
349
- ) -> Tuple[int, int]:
350
- if n_head is None and head_dim is None:
351
- head_dim = config.n_embd // config.n_head
352
- n_head = config.n_head
353
- elif n_head is None or head_dim is None:
354
- raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
355
-
356
- if n_head_kv is None:
357
- n_head_kv = getattr(config, "n_head_kv", None) or n_head
358
-
359
- return n_head, n_head_kv, head_dim
360
-
361
-
362
- def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
363
- num_heads, head_dim = kv.shape[-2:]
364
-
365
- if layer_idx not in inference_params.key_value_memory_dict:
366
- inference_params.key_value_memory_dict[layer_idx] = torch.empty(
367
- inference_params.max_batch_size,
368
- inference_params.max_seqlen,
369
- 2,
370
- num_heads,
371
- head_dim,
372
- dtype=kv.dtype,
373
- device=kv.device,
374
- )
375
-
376
- batch_start = inference_params.batch_size_offset
377
- batch_end = batch_start + kv.shape[0]
378
-
379
- sequence_start = inference_params.seqlen_offset
380
- sequence_end = sequence_start + kv.shape[1]
381
-
382
- # When the current sequence length is equal to or larger than the maximum sequence length,
383
- # we need to concatenate the current `kv` with the cached `kv` to expand its length
384
- if sequence_end >= inference_params.max_seqlen:
385
- inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
386
-
387
- inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
388
- kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
389
-
390
- return kv
391
-
392
-
393
- class MHA(nn.Module):
394
- """Multi-head attention layer."""
395
-
396
- def __init__(
397
- self,
398
- config: PretrainedConfig,
399
- dtype: Optional[torch.dtype] = None,
400
- device: Optional[str] = None,
401
- rotary_dim: Optional[int] = None,
402
- rotary_base: float = 10000.0,
403
- rotary_scale_base: Optional[float] = None,
404
- n_head: Optional[int] = None,
405
- n_head_kv: Optional[int] = None,
406
- head_dim: Optional[int] = None,
407
- bias: bool = True,
408
- causal: bool = True,
409
- softmax_scale: Optional[float] = None,
410
- layer_idx: Optional[int] = None,
411
- return_residual: bool = False,
412
- checkpointing: bool = False,
413
- ) -> None:
414
- super().__init__()
415
-
416
- # Rotary embedding
417
- self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
418
- if self.rotary_dim > 0:
419
- self.rotary_emb = RotaryEmbedding(
420
- d_rotary=self.rotary_dim,
421
- # d_rotary=math.ceil((rotary_dim // n_head) / 2), # d_rotary is half of d_head
422
- initial_cos_sin_cache_len=config.n_positions,
423
- )
424
-
425
- # MLP
426
- self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
427
- config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
428
- )
429
- op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
430
- hidden_size = config.n_embd
431
-
432
- linear_cls = nn.Linear
433
- if linear_cls is None:
434
- linear_cls = nn.Linear
435
-
436
- self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
437
- self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
438
-
439
- # Attention
440
- attn_cls = SelfAttention
441
- if attn_cls is None:
442
- attn_cls = SelfAttention
443
-
444
- cross_attn_cls = CrossAttention
445
- if cross_attn_cls is None:
446
- cross_attn_cls = CrossAttention
447
-
448
- self.inner_attn = attn_cls(
449
- causal=causal,
450
- softmax_scale=softmax_scale,
451
- attention_dropout=config.attn_pdrop,
452
- )
453
- self.inner_cross_attn = cross_attn_cls(
454
- causal=causal,
455
- softmax_scale=softmax_scale,
456
- attention_dropout=config.attn_pdrop,
457
- )
458
-
459
- self.layer_idx = layer_idx
460
- self.return_residual = return_residual
461
- self.checkpointing = checkpointing
462
-
463
- def _forward_self_attn(
464
- self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
465
- ) -> torch.FloatTensor:
466
- qkv = self.Wqkv(x)
467
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
468
-
469
- if self.rotary_dim > 0:
470
- qkv = self.rotary_emb(qkv)
471
-
472
- if self.checkpointing:
473
- return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
474
-
475
- return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
476
-
477
- def _forward_cross_attn(
478
- self,
479
- x: torch.FloatTensor,
480
- past_key_values: Optional[InferenceParams],
481
- key_padding_mask: Optional[torch.BoolTensor],
482
- ) -> torch.FloatTensor:
483
- batch_size = x.shape[0]
484
-
485
- qkv = self.Wqkv(x)
486
-
487
- q = qkv[..., : self.n_head * self.head_dim]
488
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
489
-
490
- kv = qkv[..., self.n_head * self.head_dim :]
491
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
492
-
493
- seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
494
- causal = None if seqlen_offset == 0 else False
495
- if self.rotary_dim > 0:
496
- q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
497
-
498
- if past_key_values is not None:
499
- kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
500
-
501
- if self.checkpointing:
502
- return torch.utils.checkpoint.checkpoint(
503
- self.inner_cross_attn,
504
- q,
505
- kv,
506
- key_padding_mask=key_padding_mask,
507
- causal=causal,
508
- )
509
-
510
- return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
511
-
512
- def forward(
513
- self,
514
- x: torch.FloatTensor,
515
- past_key_values: Optional[InferenceParams] = None,
516
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
517
- **kwargs,
518
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
519
- if attention_mask is not None:
520
- attention_mask = attention_mask.bool()
521
- else:
522
- attention_mask = None
523
-
524
- # MHA
525
- if self.n_head == self.n_head_kv:
526
- if past_key_values is None:
527
- # If `past_key_values` are not supplied, we run self-attention
528
- attn_output = self._forward_self_attn(x, attention_mask)
529
- else:
530
- # If `past_key_values` are supplied, it means that we might have cached values and
531
- # could take advantage of cross-attention
532
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
533
- # MQA / GQA
534
- else:
535
- # Regardless of `past_key_values` being supplied or not, it always use cross-attention
536
- # because `q` and `kv` lengths might be different
537
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
538
-
539
- output = rearrange(attn_output, "... h d -> ... (h d)")
540
- output = self.out_proj(output)
541
-
542
- return output if not self.return_residual else (output, x)
543
-
544
-
545
- class ParallelBlock(nn.Module):
546
- """Parallel block.
547
-
548
- This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
549
-
550
- """
551
-
552
- def __init__(
553
- self,
554
- config: PretrainedConfig,
555
- block_idx: Optional[int] = None,
556
- ) -> None:
557
- super().__init__()
558
-
559
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
560
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
561
- self.block_idx = block_idx
562
-
563
- self.mixer = MHA(config, layer_idx=block_idx)
564
- self.mlp = MLP(config)
565
-
566
- def forward(
567
- self,
568
- hidden_states: torch.FloatTensor,
569
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
570
- attention_mask: Optional[torch.BoolTensor] = None,
571
- **kwargs,
572
- ) -> torch.FloatTensor:
573
- residual = hidden_states
574
- hidden_states = self.ln(hidden_states)
575
-
576
- attn_outputs = self.mixer(
577
- hidden_states,
578
- past_key_values=past_key_values,
579
- attention_mask=attention_mask,
580
- )
581
- if isinstance(attn_outputs, tuple):
582
- attn_outputs = attn_outputs[0]
583
-
584
- attn_outputs = self.resid_dropout(attn_outputs)
585
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
586
-
587
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
588
-
589
- return hidden_states
590
-
591
-
592
- class CausalLMHead(nn.Module):
593
- """Causal Language Modeling head.
594
-
595
- Reference:
596
- Improving Language Understanding by Generative Pre-Training.
597
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
598
-
599
- """
600
-
601
- def __init__(self, config: PretrainedConfig) -> None:
602
- super().__init__()
603
-
604
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
605
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
606
-
607
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
608
- hidden_states = self.ln(hidden_states)
609
- logits = self.linear(hidden_states).to(torch.float32)
610
-
611
- return logits
612
-
613
-
614
- class CausalLMLoss(nn.Module):
615
- """Causal Language Modeling loss.
616
-
617
- Reference:
618
- Improving Language Understanding by Generative Pre-Training.
619
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
620
-
621
- """
622
-
623
- def __init__(self, shift_labels: bool = True) -> None:
624
- super().__init__()
625
-
626
- self.shift_labels = shift_labels
627
- self.loss_fct = nn.CrossEntropyLoss()
628
-
629
- def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
630
- if self.shift_labels:
631
- logits = logits[..., :-1, :].contiguous()
632
- labels = labels[..., 1:].contiguous()
633
-
634
- loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
635
-
636
- return loss
637
-
638
-
639
- class PhiPreTrainedModel(PreTrainedModel):
640
- """Phi pre-trained model."""
641
-
642
- config_class = PhiConfig
643
- base_model_prefix = "transformer"
644
- supports_gradient_checkpointing = False
645
- _no_split_modules = ["ParallelBlock"]
646
-
647
- def __init__(self, *inputs, **kwargs) -> None:
648
- super().__init__(*inputs, **kwargs)
649
-
650
- def _init_weights(self, module: nn.Module) -> None:
651
- if isinstance(module, (nn.Linear,)):
652
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
653
- if module.bias is not None:
654
- module.bias.data.zero_()
655
- elif isinstance(module, nn.Embedding):
656
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
657
- if module.padding_idx is not None:
658
- module.weight.data[module.padding_idx].zero_()
659
- elif isinstance(module, nn.LayerNorm):
660
- if module.bias is not None:
661
- module.bias.data.zero_()
662
- module.weight.data.fill_(1.0)
663
-
664
- def prepare_inputs_for_generation(
665
- self,
666
- input_ids: torch.LongTensor,
667
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
668
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
669
- **kwargs,
670
- ) -> Dict[str, Any]:
671
- if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
672
- past_key_values = InferenceParams(
673
- max_seqlen=self.config.n_positions,
674
- max_batch_size=input_ids.shape[0],
675
- seqlen_offset=0,
676
- batch_size_offset=0,
677
- key_value_memory_dict={},
678
- lengths_per_sample=None,
679
- )
680
- else:
681
- # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
682
- past_key_values.seqlen_offset = input_ids.shape[1] - 1
683
- input_ids = input_ids[:, -1].unsqueeze(-1)
684
-
685
- return {
686
- "input_ids": input_ids,
687
- "past_key_values": past_key_values,
688
- "attention_mask": attention_mask,
689
- }
690
-
691
-
692
- class PhiModel(PhiPreTrainedModel):
693
- """Phi model."""
694
-
695
- _keys_to_ignore_on_load_missing = [""]
696
- _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
697
-
698
- def __init__(self, config: PhiConfig) -> None:
699
- super().__init__(config)
700
-
701
- self.embd = Embedding(config)
702
- self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
703
- self.gradient_checkpointing = False
704
- self.post_init()
705
-
706
- def get_input_embeddings(self) -> nn.Embedding:
707
- return self.embd.wte
708
-
709
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
710
- self.embd.wte = new_embeddings
711
-
712
- def forward(
713
- self,
714
- input_ids: torch.LongTensor,
715
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
716
- attention_mask: Optional[torch.BoolTensor] = None,
717
- ) -> torch.FloatTensor:
718
- hidden_states = self.embd(input_ids)
719
-
720
- for layer in self.h:
721
- hidden_states = layer(
722
- hidden_states,
723
- past_key_values=past_key_values,
724
- attention_mask=attention_mask,
725
- )
726
-
727
- return hidden_states
728
-
729
-
730
- class PhiForCausalLM(PhiPreTrainedModel):
731
- """Phi for Causal Language Modeling."""
732
-
733
- _keys_to_ignore_on_load_missing = [""]
734
- _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
735
-
736
- def __init__(self, config: PhiConfig) -> None:
737
- super().__init__(config)
738
-
739
- self.transformer = PhiModel(config)
740
- self.lm_head = CausalLMHead(config)
741
- self.loss = CausalLMLoss()
742
-
743
- self.post_init()
744
-
745
- def get_output_embeddings(self) -> nn.Linear:
746
- return self.lm_head.linear
747
-
748
- def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
749
- self.lm_head.linear = new_embeddings
750
-
751
- def forward(
752
- self,
753
- input_ids: torch.LongTensor,
754
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
755
- attention_mask: Optional[torch.BoolTensor] = None,
756
- labels: Optional[torch.LongTensor] = None,
757
- **kwargs,
758
- ) -> CausalLMOutputWithPast:
759
- hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
760
- lm_logits = self.lm_head(hidden_states)
761
-
762
- loss = None
763
- if labels is not None:
764
- loss = self.loss(lm_logits, labels)
765
-
766
- return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
phi2_configuration.py ADDED
@@ -0,0 +1,60 @@
1
+ import math
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class Phi2Config(PretrainedConfig):
6
+ model_type = "phi2" # not necessary unless you want to register model with auto classes
7
+ attribute_map = {
8
+ "max_position_embeddings": "initial_cos_sin_cache_len",
9
+ "hidden_size": "d_embedding",
10
+ "num_attention_heads": "n_attn_heads",
11
+ "num_hidden_layers": "n_blocks",
12
+ }
13
+
14
+ def __init__(
15
+ self,
16
+ vocab_size: int = 50295, # this includes the extra tokens included by Phi2 in tokenizer_config.json
17
+ vocab_chunk_for_gpu_efficiency: int = 64,
18
+ initial_cos_sin_cache_len: int = 2048,
19
+ d_embedding: int = 1024, # 2560?
20
+ n_blocks: int = 20, # 32?
21
+ n_attn_heads: int = 16, # 32?
22
+ use_flash_attn: bool = False,
23
+ use_flash_rotary: bool = False,
24
+ use_fused_dense: bool = False,
25
+ attn_pdrop: float = 0.0,
26
+ embd_pdrop: float = 0.0,
27
+ resid_pdrop: float = 0.0,
28
+ layer_norm_epsilon: float = 1e-5,
29
+ weight_initialization_range: float = 0.02,
30
+ tie_word_embeddings: bool = False, # whether the input embedding weights are shared with the output (LM head) projection
31
+ checkpointing: bool = False, # whether to use gradient checkpointing (torch.utils.checkpoint) to trade compute for memory
32
+ **kwargs
33
+ ) -> None:
34
+ self.vocab_size = (
35
+ math.ceil(
36
+ vocab_size / vocab_chunk_for_gpu_efficiency
37
+ ) * vocab_chunk_for_gpu_efficiency
38
+ )
39
+ self.initial_cos_sin_cache_len = initial_cos_sin_cache_len
40
+ self.d_embedding = d_embedding
41
+ self.n_blocks = n_blocks
42
+ self.n_attn_heads = n_attn_heads
43
+ self.use_flash_attn = use_flash_attn
44
+ self.use_flash_rotary = use_flash_rotary
45
+ self.use_fused_dense = use_fused_dense
46
+ self.attn_pdrop = attn_pdrop
47
+ self.embd_pdrop = embd_pdrop
48
+ self.resid_pdrop = resid_pdrop
49
+ self.layer_norm_epsilon = layer_norm_epsilon
50
+ self.weight_initialization_range = weight_initialization_range
51
+ self.checkpointing = checkpointing
52
+
53
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
54
+
55
+
56
+ if __name__ == "__main__":
57
+ phi2_config = Phi2Config()
58
+ # phi2_config.save_pretrained("phi2_config")
59
+ # phi2_config = Phi2Config.from_pretrained("phi2_config")
60
+ # phi2_config.push_to_hub("phi2_config")
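One detail worth spelling out: vocab_size is rounded up to a multiple of vocab_chunk_for_gpu_efficiency, which is how the default 50295 becomes the 50304 written into config.json. A quick check of that arithmetic (not part of the commit):

import math

vocab_size, chunk = 50295, 64
padded = math.ceil(vocab_size / chunk) * chunk  # round up to the next multiple of 64
assert padded == 50304  # matches the vocab_size value in config.json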
phi2_model.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import CausalLMOutputWithPast
5
+ from typing import Any, cast
6
+
7
+ from .attention import ParallelAttentionBlock, KVCache
8
+ from .phi2_configuration import Phi2Config
9
+
10
+
11
+ class Phi2PreTrainedModel(PreTrainedModel):
12
+ config_class = Phi2Config # not necessary unless you want to register model with auto classes
13
+ supports_gradient_checkpointing = False
14
+ # _no_split_modules = ["ParallelAttentionBlock"]
15
+
16
+ # weight loading
17
+ # base_model_prefix = "transformer"
18
+ # _keys_to_ignore_on_load_missing = [""]
19
+ # _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
20
+
21
+ def __init__(self, config: Phi2Config):
22
+ super().__init__(config)
23
+ self.config = config
24
+
25
+ def _init_weights(self, module: nn.Module) -> None:
26
+ # initialize weights - will get overwritten by saved weights in from_pretrained() if they exist
27
+ if isinstance(module, (nn.Linear,)):
28
+ module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
29
+ if module.bias is not None:
30
+ module.bias.data.zero_()
31
+ elif isinstance(module, nn.Embedding):
32
+ module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
33
+ if module.padding_idx is not None:
34
+ module.weight.data[module.padding_idx].zero_()
35
+ elif isinstance(module, nn.LayerNorm):
36
+ if module.bias is not None:
37
+ module.bias.data.zero_()
38
+ module.weight.data.fill_(1.0)
39
+
40
+ def prepare_inputs_for_generation(
41
+ self,
42
+ input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
43
+ kv_cache: KVCache | None = None,
44
+ key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
45
+ ) -> dict[str, Any]:
46
+ if not kv_cache:
47
+ kv_cache = KVCache(
48
+ max_seqlen=self.config.initial_cos_sin_cache_len,
49
+ max_batch_size=input_ids.shape[0],
50
+ seqlen_offset=0,
51
+ batch_size_offset=0,
52
+ kv_block_map={},
53
+ lengths_per_sample=None,
54
+ )
55
+ else:
56
+ # assume that `kv_cache` has cached all tokens up to the last token in `input_ids`
57
+ kv_cache.seqlen_offset = input_ids.shape[1] - 1
58
+ input_ids = cast(torch.LongTensor, input_ids[:, -1].unsqueeze(-1))
59
+
60
+ return { # to be passed to forward()
61
+ "input_ids": input_ids,
62
+ "kv_cache": kv_cache,
63
+ "key_padding_mask": key_padding_mask,
64
+ }
65
+
66
+
67
+ class Embedding(nn.Module):
68
+ """Token embedding with dropout from Phi2."""
69
+
70
+ def __init__(
71
+ self,
72
+ vocab_size: int,
73
+ d_embedding: int,
74
+ embd_pdrop: float,
75
+ ) -> None:
76
+ super().__init__()
77
+ self.embeddings = nn.Embedding(vocab_size, d_embedding)
78
+ self.dropout = nn.Dropout(embd_pdrop)
79
+
80
+ def forward(
81
+ self,
82
+ input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
83
+ ) -> torch.FloatTensor:
84
+ x = self.embeddings( # dim: (batch_size, seq_len, d_embedding)
85
+ input_ids.view(-1, input_ids.size()[-1])
86
+ )
87
+ x = self.dropout(x)
88
+ return x
89
+
90
+
91
+ class Phi2Model(Phi2PreTrainedModel):
92
+ def __init__(self, config: Phi2Config) -> None:
93
+ super().__init__(config)
94
+ self.embedding = Embedding(
95
+ vocab_size=config.vocab_size,
96
+ d_embedding=config.d_embedding,
97
+ embd_pdrop=config.embd_pdrop,
98
+ )
99
+ self.parallel_blocks = nn.ModuleList([
100
+ ParallelAttentionBlock(
101
+ resid_pdrop=config.resid_pdrop,
102
+ layer_norm_epsilon=config.layer_norm_epsilon,
103
+ d_embedding=config.d_embedding,
104
+ n_attn_heads=config.n_attn_heads,
105
+ block_n=i,
106
+ initial_cos_sin_cache_len=config.initial_cos_sin_cache_len,
107
+ attn_pdrop=config.attn_pdrop,
108
+ use_flash_rotary=config.use_flash_rotary,
109
+ use_flash_attn=config.use_flash_attn,
110
+ use_fused_dense=config.use_fused_dense,
111
+ checkpointing=config.checkpointing,
112
+ )
113
+ for i in range(config.n_blocks)
114
+ ])
115
+ self.gradient_checkpointing_disable() # https://github.com/cybertronai/gradient-checkpointing - I think this is turned off due to flash attention?
116
+ self.post_init() # calls self._init_weights() for all modules
117
+
118
+ """
119
+ def get_input_embeddings(self) -> nn.Embedding:
120
+ return self.embedding.embeddings
121
+
122
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
123
+ self.embedding.embeddings = new_embeddings
124
+ """
125
+
126
+ def forward(
127
+ self,
128
+ input_ids: torch.LongTensor,
129
+ kv_cache: KVCache | None = None,
130
+ key_padding_mask: torch.BoolTensor | None = None,
131
+ ) -> torch.FloatTensor:
132
+ x = self.embedding(input_ids)
133
+ for block in self.parallel_blocks:
134
+ x = block(
135
+ x,
136
+ kv_cache=kv_cache,
137
+ key_padding_mask=key_padding_mask,
138
+ )
139
+ return x
140
+
141
+
142
+ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
143
+ def __init__(self, config: Phi2Config) -> None:
144
+ super().__init__(config)
145
+ self.pretrained_model = Phi2Model(config)
146
+ self.layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
147
+ self.linear = nn.Linear(config.d_embedding, config.vocab_size)
148
+ self.loss_fn = nn.CrossEntropyLoss()
149
+ self.post_init() # calls self._init_weights() for all modules
150
+
151
+ def forward(
152
+ self,
153
+ input_ids: torch.LongTensor,
154
+ kv_cache: KVCache | None = None,
155
+ key_padding_mask: torch.BoolTensor | None = None,
156
+ labels: torch.LongTensor | None = None,
157
+ ) -> CausalLMOutputWithPast:
158
+ x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
159
+ x = self.layer_norm(x)
160
+ logits = self.linear(x).to(torch.float32)
161
+ loss = (
162
+ self.loss_fn(logits[:, :-1, :].reshape(-1, logits.size(-1)), labels[:, 1:].reshape(-1)) # shift so position t predicts token t+1
163
+ if labels is not None
164
+ else None
165
+ )
166
+ return CausalLMOutputWithPast(loss=loss, logits=logits)
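A hedged end-to-end sketch (not part of the commit) of how the pieces above fit together: build a deliberately tiny Phi2Config, wrap it in Phi2ModelForCausalLM, and run one forward pass with labels to get the causal-LM loss. It assumes the three modules are installed as a package so the relative imports in phi2_model.py resolve; the package name below is hypothetical.

import torch
from simplified_phi2.phi2_configuration import Phi2Config    # hypothetical package name; adjust to your layout
from simplified_phi2.phi2_model import Phi2ModelForCausalLM

config = Phi2Config(d_embedding=256, n_blocks=2, n_attn_heads=8)  # small values for a smoke test
model = Phi2ModelForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch_size, seq_len)
output = model(input_ids=input_ids, labels=input_ids)     # returns CausalLMOutputWithPast
print(output.loss, output.logits.shape)                   # logits: (2, 16, config.vocab_size)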