Alexandru Gherghescu committed
Commit 9dde569
Parent(s): 7e53000
Fix modeling_gpt1.py
Few fixes:
- correctly apply the norm after each sub-layer, as in the original Transformer paper
- fix the model output
- use LayerNorm instead of RMSNorm

modeling_gpt1.py +13 -31
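For context on the LayerNorm/RMSNorm fix: the removed GPT1RMSNorm only rescales each hidden vector by its root-mean-square, while nn.LayerNorm also subtracts the mean and applies a learned bias. A minimal standalone sketch of the difference (sizes and eps are illustrative, not taken from GPT1Config):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 8)  # (batch, seq_len, hidden_size)

# RMS-style normalization, mirroring the removed class: divide by the
# root-mean-square over the last dimension; no mean subtraction, no bias.
rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
x_rms = x / (rms + 1e-5)  # 1e-5 is an illustrative eps, not config.layer_norm_eps

# nn.LayerNorm: subtract the mean, divide by the standard deviation, then
# apply a learned affine transform (weight and bias).
layer_norm = nn.LayerNorm(normalized_shape=8, eps=1e-5)
x_ln = layer_norm(x)

# Only the LayerNorm output is mean-centered per hidden vector.
print(x_rms[0, 0].mean().item(), x_ln[0, 0].mean().item())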
modeling_gpt1.py
CHANGED
@@ -9,33 +9,15 @@ from transformers.modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
 )
-from transformers.activations import
+from transformers.activations import get_activation

 from configuration_gpt1 import GPT1Config


-class GPT1RMSNorm(nn.Module):
-    def __init__(self, config: GPT1Config):
-        super().__init__()
-        self.config = config
-        self.weight = nn.Parameter(torch.ones(config.hidden_size))
-
-    def _norm(self, x):
-        std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
-        return x / (std + self.config.layer_norm_eps)
-
-    def forward(self, hidden_state):
-        input_dtype = hidden_state.dtype
-        # compute in float32, not in fp16, since normalization needs to be accurate
-        hidden_state = hidden_state.float()
-        output = self._norm(hidden_state).type_as(input_dtype)
-        return output * self.weight
-
-
 class GPT1MLP(nn.Module):
     def __init__(self, config: GPT1Config):
         super().__init__()
-        self.activation_fn =
+        self.activation_fn = get_activation(config.hidden_act)
         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

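get_activation comes from transformers.activations and maps a configuration string such as "gelu" to the corresponding activation module, which is how the MLP now resolves config.hidden_act. A small usage sketch with illustrative sizes (the MLP wiring below is a stand-in, not the file's exact code):

import torch
import torch.nn as nn
from transformers.activations import get_activation

# Resolve an activation module from its config string, as the MLP now does
# with config.hidden_act.
activation_fn = get_activation("gelu")

# Toy MLP mirroring the GPT1MLP structure: expand, activate, project back.
hidden_size, intermediate_size = 16, 64
fc1 = nn.Linear(hidden_size, intermediate_size)
fc2 = nn.Linear(intermediate_size, hidden_size)

x = torch.randn(2, 5, hidden_size)
out = fc2(activation_fn(fc1(x)))
print(out.shape)  # torch.Size([2, 5, 16])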
@@ -109,25 +91,27 @@ class GPT1DecoderLayer(nn.Module):
         self.attention = GPT1Attention(config)
         self.mlp = GPT1MLP(config)

-        self.attention_norm =
-        self.mlp_norm =
+        self.attention_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
+                                           eps=config.layer_norm_eps)
+        self.mlp_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
+                                     eps=config.layer_norm_eps)

         self.res_dropout = nn.Dropout(p=config.resid_pdrop)

     def forward(self, hidden_state, attn_mask):
         # attention
         residual = hidden_state
-        hidden_state = self.attention_norm(hidden_state)
         hidden_state = self.attention(hidden_state, attn_mask)
         hidden_state = self.res_dropout(hidden_state)
         hidden_state = residual + hidden_state
+        hidden_state = self.attention_norm(hidden_state)

         # feed forward fully connected
         residual = hidden_state
-        hidden_state = self.mlp_norm(hidden_state)
         hidden_state = self.mlp(hidden_state)
         hidden_state = self.res_dropout(hidden_state)
         hidden_state = residual + hidden_state
+        hidden_state = self.mlp_norm(hidden_state)

         return hidden_state

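This reordering switches the block from pre-norm to the post-norm layout of the original Transformer paper, where the norm is applied to the residual sum, x = LayerNorm(x + Sublayer(x)). A minimal sketch of the two orderings (the linear sub-layer and sizes are placeholders for the attention and MLP modules):

import torch
import torch.nn as nn

hidden_size = 16
norm = nn.LayerNorm(hidden_size)
sublayer = nn.Linear(hidden_size, hidden_size)  # stand-in for attention / MLP
dropout = nn.Dropout(p=0.1)

x = torch.randn(2, 5, hidden_size)

# Pre-norm (what the code did before this commit): normalize the input to the
# sub-layer; the residual stream itself is never normalized.
pre = x + dropout(sublayer(norm(x)))

# Post-norm (this commit, as in "Attention Is All You Need"): run the
# sub-layer, add the residual, then normalize the sum.
post = norm(x + dropout(sublayer(x)))

print(pre.shape, post.shape)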
@@ -165,8 +149,6 @@ class GPT1Model(GPT1PreTrainedModel):
             [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )

-        self.norm = GPT1RMSNorm(config)
-
         causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
                                  fill_value=float('-inf'))
         self.register_buffer('causal_mask',
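The separate final norm is dropped here, presumably because with the post-norm layout above every decoder layer already ends in mlp_norm. The causal mask buffer starts from a matrix filled with -inf; the register_buffer call is truncated in the diff, so the triu step in the sketch below is an assumption about how the mask is finished, not a quote from the file:

import torch

max_position_embeddings = 6

# As in the constructor above: start from an all -inf matrix.
causal_mask = torch.full((1, max_position_embeddings, max_position_embeddings),
                         fill_value=float('-inf'))

# Assumed continuation: keep -inf strictly above the diagonal and zero the
# rest, so adding the mask to attention scores blocks future positions.
causal_mask = torch.triu(causal_mask, diagonal=1)

print(causal_mask[0])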
@@ -182,9 +164,9 @@ class GPT1Model(GPT1PreTrainedModel):
         self.embs = value

     def forward(self, input_ids, *args, **kwargs):
-        position_ids = torch.arange(input_ids.size(
+        position_ids = torch.arange(input_ids.size(-1),
                                     dtype=torch.long,
-                                    device=input_ids.device)
+                                    device=input_ids.device).unsqueeze_(0)

         input_embeds = self.embs(input_ids) # (bs, seq_len, dim)
         position_embeds = self.pos_emb(position_ids)
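The fixed position_ids cover 0..seq_len-1 and gain a leading batch dimension via unsqueeze_(0), so they broadcast over the batch when the position embeddings are added to the token embeddings. A standalone sketch with illustrative sizes (the embedding modules stand in for self.embs and self.pos_emb; the addition is assumed from context):

import torch
import torch.nn as nn

vocab_size, max_positions, hidden_size = 100, 32, 16
embs = nn.Embedding(vocab_size, hidden_size)        # token embeddings
pos_emb = nn.Embedding(max_positions, hidden_size)  # learned position embeddings

input_ids = torch.randint(0, vocab_size, (2, 5))    # (bs, seq_len)

# One row of positions 0..seq_len-1, shaped (1, seq_len) so it broadcasts
# over the batch dimension.
position_ids = torch.arange(input_ids.size(-1),
                            dtype=torch.long,
                            device=input_ids.device).unsqueeze_(0)

hidden_state = embs(input_ids) + pos_emb(position_ids)  # (bs, seq_len, dim)
print(hidden_state.shape)  # torch.Size([2, 5, 16])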
@@ -192,11 +174,10 @@ class GPT1Model(GPT1PreTrainedModel):

         causal_mask = self.causal_mask.to(dtype=input_embeds.dtype,
                                           device=input_embeds.device)
+
         for layer in self.layers:
             hidden_state = layer(hidden_state, attn_mask=causal_mask)

-        hidden_state = self.norm(hidden_state)
-
         return BaseModelOutput(
             last_hidden_state=hidden_state
         )
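BaseModelOutput behaves both like a dataclass and like a tuple, which is why the causal-LM head in the next hunk can read the hidden states as output[0]. A brief sketch (the tensor is a dummy stand-in for the decoder's hidden states):

import torch
from transformers.modeling_outputs import BaseModelOutput

hidden_state = torch.randn(2, 5, 16)  # dummy (bs, seq_len, hidden_size) activations
output = BaseModelOutput(last_hidden_state=hidden_state)

# Attribute access and tuple-style indexing return the same tensor.
assert torch.equal(output.last_hidden_state, output[0])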
@@ -236,7 +217,8 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
     def forward(self, input_ids, labels = None, *args, **kwargs):
         output = self.model(input_ids)

-
+        hidden_state = output[0]
+        logits = self.lm_head(hidden_state).float()

         loss = None
         if labels is not None:
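The hunk cuts off right before the loss computation. The usual causal-LM objective shifts logits and labels by one position so each token predicts the next one; the sketch below shows that standard pattern, not the repository's exact code, and the result would presumably be wrapped in the CausalLMOutput imported at the top of the file:

import torch
import torch.nn as nn

vocab_size = 100
logits = torch.randn(2, 5, vocab_size)         # e.g. lm_head output, (bs, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))  # (bs, seq_len)

# Standard next-token objective: logits at positions 0..n-2 predict labels at
# positions 1..n-1, so both are shifted by one before the cross-entropy.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()

loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())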