zxdu20 committed on
Commit
5c64357
1 Parent(s): 8127ab6

Set ignore_index for CrossEntropyLoss

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -1124,7 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1124
  shift_logits = lm_logits[..., :-1, :].contiguous()
1125
  shift_labels = labels[..., 1:].contiguous()
1126
  # Flatten the tokens
1127
- loss_fct = CrossEntropyLoss()
1128
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1129
 
1130
  lm_logits = lm_logits.to(hidden_states.dtype)
 
1124
  shift_logits = lm_logits[..., :-1, :].contiguous()
1125
  shift_labels = labels[..., 1:].contiguous()
1126
  # Flatten the tokens
1127
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1128
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1129
 
1130
  lm_logits = lm_logits.to(hidden_states.dtype)