jon-tow commited on
Commit
a4750ac
1 Parent(s): 846682b

fix: init proper term

Browse files
Files changed (1) hide show
  1. modeling_stablelm_epoch.py +2 -2
modeling_stablelm_epoch.py CHANGED
@@ -609,7 +609,7 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
609
  hidden_states = outputs[0]
610
  logits = self.lm_head(hidden_states).float()
611
 
612
- lm_loss = None
613
  if labels is not None:
614
  # Shift so that tokens < n predict n
615
  shift_logits = logits[..., :-1, :].contiguous()
@@ -627,7 +627,7 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
627
  return (loss,) + output if loss is not None else output
628
 
629
  return CausalLMOutputWithPast(
630
- loss=lm_loss,
631
  logits=logits,
632
  past_key_values=outputs.past_key_values,
633
  hidden_states=outputs.hidden_states,
 
609
  hidden_states = outputs[0]
610
  logits = self.lm_head(hidden_states).float()
611
 
612
+ loss = None
613
  if labels is not None:
614
  # Shift so that tokens < n predict n
615
  shift_logits = logits[..., :-1, :].contiguous()
 
627
  return (loss,) + output if loss is not None else output
628
 
629
  return CausalLMOutputWithPast(
630
+ loss=loss,
631
  logits=logits,
632
  past_key_values=outputs.past_key_values,
633
  hidden_states=outputs.hidden_states,