noahtren commited on
Commit
e282cdc
1 Parent(s): d318676

return past hidden states when `output_hidden_states` provided

Browse files

The model should be able to return past hidden states, as expected in the `forward()` function for HuggingFace models.

Added this because I needed it, but additional support for `output_attentions` param could be added.

Files changed (1) hide show
  1. modeling_phi.py +2 -1
modeling_phi.py CHANGED
@@ -947,6 +947,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
947
  input_ids: torch.LongTensor,
948
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
949
  attention_mask: Optional[torch.BoolTensor] = None,
 
950
  labels: Optional[torch.LongTensor] = None,
951
  **kwargs,
952
  ) -> CausalLMOutputWithPast:
@@ -957,4 +958,4 @@ class PhiForCausalLM(PhiPreTrainedModel):
957
  if labels is not None:
958
  loss = self.loss(lm_logits, labels)
959
 
960
- return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
 
947
  input_ids: torch.LongTensor,
948
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
949
  attention_mask: Optional[torch.BoolTensor] = None,
950
+ output_hidden_states: Optional[bool] = None,
951
  labels: Optional[torch.LongTensor] = None,
952
  **kwargs,
953
  ) -> CausalLMOutputWithPast:
 
958
  if labels is not None:
959
  loss = self.loss(lm_logits, labels)
960
 
961
+ return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values, hidden_states=hidden_states if output_hidden_states else None)