Alexandru Gherghescu committed on
Commit fe8246f
1 Parent(s): 030a9e9

Fix modeling_gpt1.py


Fix an issue with the attention mask, whose size was incorrect during
training and inference.

Files changed (1)
  1. modeling_gpt1.py +19 -9
modeling_gpt1.py CHANGED
@@ -154,6 +154,7 @@ class GPT1Model(GPT1PreTrainedModel):
         self.register_buffer('causal_mask',
                              torch.triu(causal_mask, diagonal=1),
                              persistent=False)
+        self.mask_cache_len = config.max_position_embeddings
 
         self.post_init()
 
@@ -172,12 +173,18 @@ class GPT1Model(GPT1PreTrainedModel):
         position_embeds = self.pos_emb(position_ids)
         hidden_state = self.embs_dropout(input_embeds) + position_embeds
 
-        if attention_mask is not None:
-            causal_mask = attention_mask.to(dtype=input_embeds.dtype,
-                                            device=input_embeds.device)
-        else:
-            causal_mask = self.causal_mask.to(dtype=input_embeds.dtype,
-                                              device=input_embeds.device)
+        if attention_mask is not None and attention_mask.size(1) > self.mask_cache_len:
+            seq_len = attention_mask.size(1)
+            self.mask_cache_len = seq_len
+
+            causal_mask = torch.full((seq_len, seq_len),
+                                     fill_value=float('-inf'))
+            self.register_buffer('causal_mask',
+                                 torch.triu(causal_mask, diagonal=1),
+                                 persistent=False)
+
+        causal_mask = self.causal_mask.to(dtype=input_embeds.dtype,
+                                          device=input_embeds.device)
 
         for layer in self.layers:
             hidden_state = layer(hidden_state, attn_mask=causal_mask)
@@ -240,10 +247,13 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
             logits=logits
         )
 
-    def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
-        seq_len = input_ids.size(1)
+    def prepare_inputs_for_generation(self, input_ids, attention_mask,
+                                      *args, **kwargs):
+        assert attention_mask.size(1) == input_ids.size(1)
+
+        seq_len = attention_mask.size(1)
 
-        attn_mask = torch.full((1, seq_len, seq_len), fill_value=float('-inf'))
+        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
         attn_mask = torch.triu(attn_mask, diagonal=1)
 
         return {
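
For context, the pattern applied in the hunks above is a lazily grown, non-persistent causal-mask buffer: the model pre-registers an additive mask sized to max_position_embeddings and re-registers a larger one only when a longer attention mask arrives. The following is a minimal standalone sketch of that pattern, not code from the repository; the class name TinyMaskCache, the max_positions parameter, and the slice to the current sequence length are assumptions of the example rather than something this commit does.

import torch
import torch.nn as nn

class TinyMaskCache(nn.Module):
    """Caches an additive causal mask and grows it on demand (sketch only)."""

    def __init__(self, max_positions: int = 512):
        super().__init__()
        # Additive mask: 0 on/below the diagonal, -inf strictly above it,
        # so future positions are zeroed out by the attention softmax.
        mask = torch.full((max_positions, max_positions), float('-inf'))
        self.register_buffer('causal_mask', torch.triu(mask, diagonal=1),
                             persistent=False)
        self.mask_cache_len = max_positions

    def forward(self, hidden_state):
        seq_len = hidden_state.size(1)
        if seq_len > self.mask_cache_len:
            # Re-register a larger buffer; nn.Module simply replaces the old one.
            self.mask_cache_len = seq_len
            mask = torch.full((seq_len, seq_len), float('-inf'))
            self.register_buffer('causal_mask', torch.triu(mask, diagonal=1),
                                 persistent=False)
        # Slicing to the current length is an assumption of this sketch,
        # not something the commit above does.
        return self.causal_mask[:seq_len, :seq_len].to(
            dtype=hidden_state.dtype, device=hidden_state.device)

# Example: a 6-token batch forces the 4-position cache to grow.
model = TinyMaskCache(max_positions=4)
mask = model(torch.randn(1, 6, 8))   # -> (6, 6) additive causal mask

Keeping the buffer non-persistent leaves it out of saved checkpoints, and re-registering it only when a longer sequence shows up avoids rebuilding the mask on every forward pass.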