pirroh committed on
Commit
3f27e64
1 Parent(s): b634e3d

Delete gpt_blocks.py

Files changed (1)
  1. gpt_blocks.py +0 -90
gpt_blocks.py DELETED
@@ -1,90 +0,0 @@
- # Copyright 2022 MosaicML Examples authors
- # SPDX-License-Identifier: Apache-2.0
-
- """GPT Blocks used for the GPT Model."""
-
- from typing import Optional, Tuple
-
- import torch
- import torch.nn as nn
-
- from .attention import MultiheadAttention
- from .low_precision_layernorm import LPLayerNorm
-
-
- class GPTMLP(nn.Module):
-
-     def __init__(self,
-                  d_model: int,
-                  mlp_ratio: int,
-                  device: Optional[str] = None):
-         super().__init__()
-         self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
-         self.mlp_act = nn.GELU(approximate='none')
-         self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
-         self.mlp_down._is_residual = True  # type: ignore
-
-     def forward(self, x):
-         return self.mlp_down(self.mlp_act(self.mlp_up(x)))
-
-
- class GPTBlock(nn.Module):
-
-     def __init__(self,
-                  attn_impl: str,
-                  d_model: int,
-                  n_heads: int,
-                  mlp_ratio: int,
-                  attn_clip_qkv: Optional[float] = None,
-                  attn_qk_ln: bool = False,
-                  softmax_scale: Optional[float] = None,
-                  attn_pdrop: float = 0.0,
-                  alibi: bool = False,
-                  resid_pdrop: float = 0.0,
-                  low_precision_layernorm: bool = False,
-                  device: Optional[str] = None,
-                  **kwargs):
-         del kwargs  # unused, just to capture any extra args from the config
-         super().__init__()
-
-         layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-
-         self.ln_1 = layernorm_class(d_model, device=device)
-         self.attn = MultiheadAttention(
-             attn_impl=attn_impl,
-             attn_clip_qkv=attn_clip_qkv,
-             attn_qk_ln=attn_qk_ln,
-             softmax_scale=softmax_scale,
-             attn_pdrop=attn_pdrop,
-             d_model=d_model,
-             n_heads=n_heads,
-             device=device,
-         )
-         self.ln_2 = layernorm_class(d_model, device=device)
-         self.mlp = GPTMLP(
-             d_model=d_model,
-             mlp_ratio=mlp_ratio,
-             device=device,
-         )
-         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
-         self.resid_mlp_dropout = nn.Dropout(resid_pdrop)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         attn_bias: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         is_causal: bool = True,
-     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
-         a = self.ln_1(x)
-         b, _, past_key_value = self.attn(a,
-                                          past_key_value=past_key_value,
-                                          attn_bias=attn_bias,
-                                          attention_mask=attention_mask,
-                                          is_causal=is_causal)
-         x = x + self.resid_attn_dropout(b)
-         m = self.ln_2(x)
-         n = self.mlp(m)
-         x = x + self.resid_mlp_dropout(n)
-         return x, past_key_value
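
For context on what this commit removes: GPTBlock is a pre-LayerNorm transformer block, where the attention and MLP sub-layers are each applied to a normalized input and added back to the residual stream through dropout. Below is a minimal usage sketch, not documentation from this repo: it assumes the sibling modules `attention.MultiheadAttention` and `low_precision_layernorm.LPLayerNorm` are importable from the same package (the file uses relative imports), and the hyperparameter values and `attn_impl='torch'` backend name are illustrative assumptions.

```python
# Hypothetical sketch of how the deleted GPTBlock could be exercised.
# Assumes it is run inside the package that also provides .attention and
# .low_precision_layernorm; values below are illustrative, not from any config.
import torch

from .gpt_blocks import GPTBlock  # the module removed by this commit

block = GPTBlock(
    attn_impl='torch',  # assumption: a plain-PyTorch attention backend name
    d_model=768,
    n_heads=12,
    mlp_ratio=4,
    resid_pdrop=0.1,
)

x = torch.randn(2, 128, 768)  # (batch, sequence, d_model)
out, past_key_value = block(x, is_causal=True)
assert out.shape == x.shape  # residual block preserves the hidden shape
```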