abhi-mosaic commited on
Commit
f438eeb
·
1 Parent(s): e27b4b2

Add loss function and bool cast

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +12 -3
modeling_mpt.py CHANGED
@@ -56,7 +56,7 @@ class MPTModel(MPTPreTrainedModel):
56
  for module in self.modules():
57
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
58
  if config.verbose:
59
- print(f'Removing bias ({module.bias}) from {module}.')
60
  module.register_parameter('bias', None)
61
  if config.verbose and config.verbose > 2:
62
  print(self)
@@ -131,6 +131,10 @@ class MPTModel(MPTPreTrainedModel):
131
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
132
  return_dict = return_dict if return_dict is not None else self.config.return_dict
133
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
134
  if not return_dict:
135
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
136
  if output_attentions:
@@ -228,7 +232,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
228
  def get_decoder(self):
229
  return self.transformer
230
 
231
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
232
  return_dict = return_dict if return_dict is not None else self.config.return_dict
233
  use_cache = use_cache if use_cache is not None else self.config.use_cache
234
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
@@ -237,7 +241,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
237
  if self.logit_scale == 0:
238
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
239
  logits *= self.logit_scale
240
- return CausalLMOutputWithPast(logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
 
 
 
 
 
241
 
242
  def param_init_fn(self, module):
243
  init_fn_name = self.config.init_config['name']
 
56
  for module in self.modules():
57
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
58
  if config.verbose:
59
+ warnings.warn(f'Removing bias ({module.bias}) from {module}.')
60
  module.register_parameter('bias', None)
61
  if config.verbose and config.verbose > 2:
62
  print(self)
 
131
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
132
  return_dict = return_dict if return_dict is not None else self.config.return_dict
133
  use_cache = use_cache if use_cache is not None else self.config.use_cache
134
+ if attention_mask is not None:
135
+ attention_mask = attention_mask.bool()
136
+ if prefix_mask is not None:
137
+ prefix_mask = prefix_mask.bool()
138
  if not return_dict:
139
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
140
  if output_attentions:
 
232
  def get_decoder(self):
233
  return self.transformer
234
 
235
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
236
  return_dict = return_dict if return_dict is not None else self.config.return_dict
237
  use_cache = use_cache if use_cache is not None else self.config.use_cache
238
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
 
241
  if self.logit_scale == 0:
242
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
243
  logits *= self.logit_scale
244
+ loss = None
245
+ if labels is not None:
246
+ labels = torch.roll(labels, shifts=-1)
247
+ labels[:, -1] = -100
248
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
249
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
250
 
251
  def param_init_fn(self, module):
252
  init_fn_name = self.config.init_config['name']