fix: init proper term
Browse files
modeling_stablelm_epoch.py
CHANGED
@@ -609,7 +609,7 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
|
|
609 |
hidden_states = outputs[0]
|
610 |
logits = self.lm_head(hidden_states).float()
|
611 |
|
612 |
-
|
613 |
if labels is not None:
|
614 |
# Shift so that tokens < n predict n
|
615 |
shift_logits = logits[..., :-1, :].contiguous()
|
@@ -627,7 +627,7 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
|
|
627 |
return (loss,) + output if loss is not None else output
|
628 |
|
629 |
return CausalLMOutputWithPast(
|
630 |
-
loss=
|
631 |
logits=logits,
|
632 |
past_key_values=outputs.past_key_values,
|
633 |
hidden_states=outputs.hidden_states,
|
|
|
609 |
hidden_states = outputs[0]
|
610 |
logits = self.lm_head(hidden_states).float()
|
611 |
|
612 |
+
loss = None
|
613 |
if labels is not None:
|
614 |
# Shift so that tokens < n predict n
|
615 |
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
627 |
return (loss,) + output if loss is not None else output
|
628 |
|
629 |
return CausalLMOutputWithPast(
|
630 |
+
loss=loss,
|
631 |
logits=logits,
|
632 |
past_key_values=outputs.past_key_values,
|
633 |
hidden_states=outputs.hidden_states,
|