Alexandru Gherghescu committed
Commit 9dde569
1 parent: 7e53000

Fix modeling_gpt1.py


A few fixes:
- apply the norm after each sub-layer's residual connection (post-norm), as in the original Transformer paper (see the sketch below)
- index the model output correctly before the LM head
- use nn.LayerNorm instead of the custom RMSNorm
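For context, a minimal standalone sketch of the ordering this commit adopts: the sub-layer output is dropped out, added back to the residual, and only then normalized (post-norm), instead of normalizing the input first (pre-norm). The PostNormBlock name, the toy Linear sub-layer and the sizes are illustrative only and not part of the repository.

import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    # post-norm residual block: sublayer -> dropout -> residual add -> LayerNorm
    def __init__(self, hidden_size, sublayer, dropout=0.1, eps=1e-5):
        super().__init__()
        self.sublayer = sublayer                 # stand-in for attention or the MLP
        self.norm = nn.LayerNorm(hidden_size, eps=eps)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        residual = x
        x = self.sublayer(x)
        x = self.dropout(x)
        return self.norm(residual + x)           # norm applied after the residual add, as in the diff below

block = PostNormBlock(hidden_size=16, sublayer=nn.Linear(16, 16))
print(block(torch.randn(2, 4, 16)).shape)        # torch.Size([2, 4, 16])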

Files changed (1)
  1. modeling_gpt1.py +13 -31
modeling_gpt1.py CHANGED
@@ -9,33 +9,15 @@ from transformers.modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
 )
-from transformers.activations import ACT2FN
+from transformers.activations import get_activation
 
 from configuration_gpt1 import GPT1Config
 
 
-class GPT1RMSNorm(nn.Module):
-    def __init__(self, config: GPT1Config):
-        super().__init__()
-        self.config = config
-        self.weight = nn.Parameter(torch.ones(config.hidden_size))
-
-    def _norm(self, x):
-        std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
-        return x / (std + self.config.layer_norm_eps)
-
-    def forward(self, hidden_state):
-        input_dtype = hidden_state.dtype
-        # compute in float32, not in fp16, since normalization needs to be accurate
-        hidden_state = hidden_state.float()
-        output = self._norm(hidden_state).type_as(input_dtype)
-        return output * self.weight
-
-
 class GPT1MLP(nn.Module):
     def __init__(self, config: GPT1Config):
         super().__init__()
-        self.activation_fn = ACT2FN(config.hidden_act)
+        self.activation_fn = get_activation(config.hidden_act)
         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
 
@@ -109,25 +91,27 @@ class GPT1DecoderLayer(nn.Module):
         self.attention = GPT1Attention(config)
         self.mlp = GPT1MLP(config)
 
-        self.attention_norm = GPT1RMSNorm(config)
-        self.mlp_norm = GPT1RMSNorm(config)
+        self.attention_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
+                                           eps=config.layer_norm_eps)
+        self.mlp_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
+                                     eps=config.layer_norm_eps)
 
         self.res_dropout = nn.Dropout(p=config.resid_pdrop)
 
     def forward(self, hidden_state, attn_mask):
         # attention
         residual = hidden_state
-        hidden_state = self.attention_norm(hidden_state)
         hidden_state = self.attention(hidden_state, attn_mask)
         hidden_state = self.res_dropout(hidden_state)
         hidden_state = residual + hidden_state
+        hidden_state = self.attention_norm(hidden_state)
 
         # feed forward fully connected
         residual = hidden_state
-        hidden_state = self.mlp_norm(hidden_state)
         hidden_state = self.mlp(hidden_state)
         hidden_state = self.res_dropout(hidden_state)
         hidden_state = residual + hidden_state
+        hidden_state = self.mlp_norm(hidden_state)
 
         return hidden_state
 
@@ -165,8 +149,6 @@ class GPT1Model(GPT1PreTrainedModel):
             [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
 
-        self.norm = GPT1RMSNorm(config)
-
         causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
                                  fill_value=float('-inf'))
         self.register_buffer('causal_mask',
@@ -182,9 +164,9 @@ class GPT1Model(GPT1PreTrainedModel):
         self.embs = value
 
     def forward(self, input_ids, *args, **kwargs):
-        position_ids = torch.arange(input_ids.size()[-1],
+        position_ids = torch.arange(input_ids.size(-1),
                                     dtype=torch.long,
-                                    device=input_ids.device)
+                                    device=input_ids.device).unsqueeze_(0)
 
         input_embeds = self.embs(input_ids) # (bs, seq_len, dim)
         position_embeds = self.pos_emb(position_ids)
@@ -192,11 +174,10 @@ class GPT1Model(GPT1PreTrainedModel):
 
         causal_mask = self.causal_mask.to(dtype=input_embeds.dtype,
                                           device=input_embeds.device)
+
         for layer in self.layers:
            hidden_state = layer(hidden_state, attn_mask=causal_mask)
 
-        hidden_state = self.norm(hidden_state)
-
         return BaseModelOutput(
             last_hidden_state=hidden_state
         )
@@ -236,7 +217,8 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
     def forward(self, input_ids, labels = None, *args, **kwargs):
         output = self.model(input_ids)
 
-        logits = self.lm_head(output).float()
+        hidden_state = output[0]
+        logits = self.lm_head(hidden_state).float()
 
         loss = None
         if labels is not None:
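Side note on the LayerNorm-for-RMSNorm swap above (illustrative only, not part of the commit): nn.LayerNorm centres the activations and learns a bias in addition to the scale, whereas the deleted GPT1RMSNorm only rescaled by the root mean square. A quick self-contained check of that difference:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 8)
eps = 1e-5

# what the removed GPT1RMSNorm computed (ignoring its learned weight)
rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
rms_normed = x / (rms + eps)

# nn.LayerNorm with default affine parameters (weight=1, bias=0)
layer_normed = nn.LayerNorm(8, eps=eps)(x)

print(layer_normed.mean(dim=-1).abs().max())   # ~0: LayerNorm subtracts the mean over the hidden dimension
print(rms_normed.mean(dim=-1).abs().max())     # generally nonzero: RMSNorm does not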