lvkaokao commited on
Commit
cac27b0
·
1 Parent(s): b9ee182

update return dict.

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +12 -2
modeling_mpt.py CHANGED
@@ -134,8 +134,8 @@ class MPTModel(MPTPreTrainedModel):
134
  attention_mask = attention_mask.bool()
135
  if prefix_mask is not None:
136
  prefix_mask = prefix_mask.bool()
137
- if not return_dict:
138
- raise NotImplementedError('return_dict False is not implemented yet for MPT')
139
  if output_attentions:
140
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
141
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
@@ -184,6 +184,9 @@ class MPTModel(MPTPreTrainedModel):
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
186
  x = self.norm_f(x)
 
 
 
187
  return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
188
 
189
  def param_init_fn(self, module):
@@ -234,6 +237,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
234
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
235
  return_dict = return_dict if return_dict is not None else self.config.return_dict
236
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
237
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
238
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
239
  if self.logit_scale is not None:
@@ -245,6 +251,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
245
  labels = torch.roll(labels, shifts=-1)
246
  labels[:, -1] = -100
247
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
 
 
 
 
248
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
249
 
250
  def param_init_fn(self, module):
 
134
  attention_mask = attention_mask.bool()
135
  if prefix_mask is not None:
136
  prefix_mask = prefix_mask.bool()
137
+ # if not return_dict:
138
+ # raise NotImplementedError('return_dict False is not implemented yet for MPT')
139
  if output_attentions:
140
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
141
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
 
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
186
  x = self.norm_f(x)
187
+ if not return_dict:
188
+ output = (x,) + (tuple(past_key_values),)
189
+ return output
190
  return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
191
 
192
  def param_init_fn(self, module):
 
237
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
238
  return_dict = return_dict if return_dict is not None else self.config.return_dict
239
  use_cache = use_cache if use_cache is not None else self.config.use_cache
240
+
241
+ past_key_values = list(past_key_values) if past_key_values is not None else None
242
+
243
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
244
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
245
  if self.logit_scale is not None:
 
251
  labels = torch.roll(labels, shifts=-1)
252
  labels[:, -1] = -100
253
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
254
+
255
+ if not return_dict:
256
+ output = (logits,) + (tuple(outputs[1]),)
257
+ return (loss,) + output if loss is not None else output
258
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
259
 
260
  def param_init_fn(self, module):