DanielHesslow commited on
Commit
568c417
·
1 Parent(s): 7eb88c8

Update to follow HF naming scheme

Browse files
Files changed (1) hide show
  1. rita_modeling.py +21 -17
rita_modeling.py CHANGED
@@ -129,8 +129,8 @@ class SelfAttention(nn.Module):
129
  def forward(
130
  self,
131
  x,
132
- attn_mask: Optional[torch.BoolTensor] = None,
133
- padding_mask: Optional[torch.BoolTensor] = None,
134
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
135
 
136
  N, L, D = x.size() # Batch_size, Context_size, d_model
@@ -153,14 +153,14 @@ class SelfAttention(nn.Module):
153
  # causal self-attention; Self-attend: (N, nh, L, hs) x (N, nh, hs, L) -> (N, nh, L, L)
154
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
155
 
156
- if attn_mask is not None:
157
- att[:,:,-L:, -L: ].masked_fill_(attn_mask.view(1, 1, L, L), float("-inf"))
158
 
159
  att = (
160
  att.transpose(0, 2)
161
- .masked_fill(padding_mask.view(1, 1, N, L), float("-inf"))
162
  .transpose(0, 2)
163
- if padding_mask is not None
164
  else att
165
  )
166
 
@@ -197,11 +197,11 @@ class DecoderLayer(nn.Module):
197
  def forward(
198
  self,
199
  x: torch.FloatTensor,
200
- attn_mask: torch.BoolTensor,
201
- padding_mask: Optional[torch.BoolTensor] = None,
202
  ) -> torch.FloatTensor:
203
  y = self.attn_norm(x)
204
- y = self.self_attention(y, attn_mask=attn_mask, padding_mask=padding_mask)
205
  x = x + self.attn_dropout(y)
206
 
207
  y = self.mlp_norm(x)
@@ -228,27 +228,27 @@ class RITAModel(PreTrainedModel):
228
  input_ids=None,
229
  past_key_values=None, # NOT USED
230
  attention_mask=None,
 
231
  token_type_ids=None, # NOT USED
232
  position_ids=None, # NOT USED
233
  head_mask=None, # NOT USED
234
  inputs_embeds=None,
235
  encoder_hidden_states=None, # NOT USED
236
- encoder_attention_mask=None, # NOT USED
237
  labels=None,
238
  use_cache=None, # NOT USED
239
  output_attentions=None, # NOT USED
240
  output_hidden_states=None, # NOT USED
241
  return_dict=None # NOT USED
242
  ) -> torch.FloatTensor:
243
-
244
  if inputs_embeds == None:
245
  x = self.embedding(input_ids) # N x L x D
246
  else:
247
  x = inputs_embeds
248
- if attention_mask == None:
249
- attention_mask = (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0).transpose(0, 1).contiguous().to(input_ids.device)
250
  for layer in self.layers:
251
- x = layer(x, attn_mask=attention_mask)
252
  x = self.final_norm(x) # N x L x D
253
 
254
  return BaseModelOutput(
@@ -295,23 +295,25 @@ class RITAModelForCausalLM(PreTrainedModel):
295
  input_ids=None,
296
  past_key_values=None, # NOT USED
297
  attention_mask=None,
 
298
  token_type_ids=None, # NOT USED
299
  position_ids=None, # NOT USED
300
  head_mask=None, # NOT USED
301
  inputs_embeds=None,
302
  encoder_hidden_states=None, # NOT USED
303
- encoder_attention_mask=None, # NOT USED
304
  labels=None,
305
  use_cache=None, # NOT USED
306
  output_attentions=None, # NOT USED
307
  output_hidden_states=None, # NOT USED
308
  return_dict=None # NOT USED
309
  ) -> torch.FloatTensor:
310
-
311
  transformer_outputs = self.transformer(
312
  input_ids,
313
  past_key_values=past_key_values,
314
- attention_mask=attention_mask,
 
315
  token_type_ids=token_type_ids,
316
  position_ids=position_ids,
317
  head_mask=head_mask,
@@ -382,6 +384,7 @@ class RITAModelForSequenceClassification(PreTrainedModel):
382
  input_ids=None,
383
  past_key_values=None,
384
  attention_mask=None,
 
385
  token_type_ids=None,
386
  position_ids=None,
387
  head_mask=None,
@@ -404,6 +407,7 @@ class RITAModelForSequenceClassification(PreTrainedModel):
404
  input_ids,
405
  past_key_values=past_key_values,
406
  attention_mask=attention_mask,
 
407
  token_type_ids=token_type_ids,
408
  position_ids=position_ids,
409
  head_mask=head_mask,
 
129
  def forward(
130
  self,
131
  x,
132
+ causal_mask: Optional[torch.BoolTensor] = None,
133
+ attention_mask: Optional[torch.BoolTensor] = None,
134
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
135
 
136
  N, L, D = x.size() # Batch_size, Context_size, d_model
 
153
  # causal self-attention; Self-attend: (N, nh, L, hs) x (N, nh, hs, L) -> (N, nh, L, L)
154
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
155
 
156
+ if causal_mask is not None:
157
+ att[:,:,-L:, -L: ].masked_fill_(causal_mask.view(1, 1, L, L), float("-inf"))
158
 
159
  att = (
160
  att.transpose(0, 2)
161
+ .masked_fill(attention_mask.view(1, 1, N, L)==0, float("-inf"))
162
  .transpose(0, 2)
163
+ if attention_mask is not None
164
  else att
165
  )
166
 
 
197
  def forward(
198
  self,
199
  x: torch.FloatTensor,
200
+ causal_mask: torch.BoolTensor,
201
+ attention_mask: Optional[torch.BoolTensor] = None,
202
  ) -> torch.FloatTensor:
203
  y = self.attn_norm(x)
204
+ y = self.self_attention(y, causal_mask=causal_mask, attention_mask=attention_mask)
205
  x = x + self.attn_dropout(y)
206
 
207
  y = self.mlp_norm(x)
 
228
  input_ids=None,
229
  past_key_values=None, # NOT USED
230
  attention_mask=None,
231
+ causal_mask=None,
232
  token_type_ids=None, # NOT USED
233
  position_ids=None, # NOT USED
234
  head_mask=None, # NOT USED
235
  inputs_embeds=None,
236
  encoder_hidden_states=None, # NOT USED
237
+ encoder_causal_mask=None, # NOT USED
238
  labels=None,
239
  use_cache=None, # NOT USED
240
  output_attentions=None, # NOT USED
241
  output_hidden_states=None, # NOT USED
242
  return_dict=None # NOT USED
243
  ) -> torch.FloatTensor:
 
244
  if inputs_embeds == None:
245
  x = self.embedding(input_ids) # N x L x D
246
  else:
247
  x = inputs_embeds
248
+ if causal_mask == None:
249
+ causal_mask = (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0).transpose(0, 1).contiguous().to(input_ids.device)
250
  for layer in self.layers:
251
+ x = layer(x, causal_mask=causal_mask, attention_mask=attention_mask)
252
  x = self.final_norm(x) # N x L x D
253
 
254
  return BaseModelOutput(
 
295
  input_ids=None,
296
  past_key_values=None, # NOT USED
297
  attention_mask=None,
298
+ causal_mask=None,
299
  token_type_ids=None, # NOT USED
300
  position_ids=None, # NOT USED
301
  head_mask=None, # NOT USED
302
  inputs_embeds=None,
303
  encoder_hidden_states=None, # NOT USED
304
+ encoder_causal_mask=None, # NOT USED
305
  labels=None,
306
  use_cache=None, # NOT USED
307
  output_attentions=None, # NOT USED
308
  output_hidden_states=None, # NOT USED
309
  return_dict=None # NOT USED
310
  ) -> torch.FloatTensor:
311
+
312
  transformer_outputs = self.transformer(
313
  input_ids,
314
  past_key_values=past_key_values,
315
+ causal_mask=causal_mask,
316
+ attention_mask = attention_mask,
317
  token_type_ids=token_type_ids,
318
  position_ids=position_ids,
319
  head_mask=head_mask,
 
384
  input_ids=None,
385
  past_key_values=None,
386
  attention_mask=None,
387
+ causal_mask=None,
388
  token_type_ids=None,
389
  position_ids=None,
390
  head_mask=None,
 
407
  input_ids,
408
  past_key_values=past_key_values,
409
  attention_mask=attention_mask,
410
+ causal_mask=causal_mask,
411
  token_type_ids=token_type_ids,
412
  position_ids=position_ids,
413
  head_mask=head_mask,