DanielHesslow committed
Commit: fde9173
Parent: d4ba49b
Files changed (1):
  1. rita_modeling.py +42 -9
rita_modeling.py CHANGED

@@ -13,6 +13,7 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
     CausalLMOutputWithPast,
+    CausalLMOutput,
 )
 
 from transformers.modeling_utils import PreTrainedModel
@@ -222,18 +223,50 @@ class RITAModel(PreTrainedModel):
         self.final_norm = nn.LayerNorm(config.d_model)
         self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
 
-    def forward(self, input_ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
-        x = self.embedding(input_ids) # N x L x D
-        if attn_mask == None:
-            attn_mask = (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0).transpose(0, 1).contiguous().to(input_ids.device)
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None) -> torch.FloatTensor:
+
+        if inputs_embeds == None:
+            x = self.embedding(input_ids) # N x L x D
+        else:
+            x = inputs_embeds
+
+        if attention_mask == None:
+            attention_mask = (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0).transpose(0, 1).contiguous().to(input_ids.device)
         for layer in self.layers:
-            x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
+            x = layer(x, attn_mask=attention_mask)
         x = self.final_norm(x) # N x L x D
 
-        if return_hidden:
-            return x
-        else:
-            return self.projector(x)
+        logits = self.projector(x)
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=x,
+        )
+
 
     #Some common HF functions.
     def get_input_embeddings(self):
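
For context, this change replaces the old return_hidden/raw-logits interface with a Hugging-Face-style forward() that accepts labels and returns a CausalLMOutput. Below is a minimal usage sketch, not part of the commit: `model` stands for an already constructed RITAModel instance, and the random batch and shapes are purely illustrative.

import torch

# Minimal sketch (assumption: `model` is an existing RITAModel built from its config).
vocab_size = model.config.vocab_size              # same vocab_size the projector uses
input_ids = torch.randint(0, vocab_size, (1, 16)) # N x L batch of token ids

# With this commit, forward() returns a CausalLMOutput instead of a raw tensor,
# and passing `labels` also computes the shifted cross-entropy LM loss.
output = model(input_ids=input_ids, labels=input_ids)
print(output.logits.shape)         # (1, 16, vocab_size) token logits
print(output.loss)                 # scalar next-token prediction loss
print(output.hidden_states.shape)  # (1, 16, d_model): the final_norm output, as returned here

Because the return type is now a CausalLMOutput, downstream code that expects .logits and .loss fields can consume the model directly, without the old return_hidden flag.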