Add gradient checkpointing for the final_layernorm module.

#77
by zhaoqf123 - opened
Files changed (1)
  1. modeling_chatglm.py +4 -1
modeling_chatglm.py CHANGED
@@ -1012,7 +1012,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
 
         # Final layer norm.
-        hidden_states = self.final_layernorm(hidden_states)
+        if self.gradient_checkpointing and self.training:
+            hidden_states = torch.utils.checkpoint.checkpoint(self.final_layernorm, hidden_states)
+        else:
+            hidden_states = self.final_layernorm(hidden_states)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
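
For context, a minimal, self-contained sketch of the trade the new branch makes, assuming a recent PyTorch; the LayerNorm size and tensor shape are illustrative stand-ins, not values taken from the repo:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative stand-in for ChatGLMModel's final_layernorm
# (4096 is an assumed hidden size, not read from the config).
final_layernorm = nn.LayerNorm(4096)
hidden_states = torch.randn(4, 2, 4096, requires_grad=True)

# Plain call: tensors saved for backward stay resident until backward runs.
out = final_layernorm(hidden_states)

# Checkpointed call, as in the patched branch: nothing from inside
# final_layernorm is kept; it is recomputed during backward, trading a
# second forward pass for lower peak activation memory. The PR calls
# checkpoint() without use_reentrant; passing it explicitly here is this
# sketch's choice to avoid the deprecation warning on newer PyTorch.
out = checkpoint(final_layernorm, hidden_states, use_reentrant=False)
out.sum().backward()

The guard `self.gradient_checkpointing and self.training` follows the usual Hugging Face pattern: the flag is toggled via `gradient_checkpointing_enable()`, and the `self.training` check skips recomputation at inference time, when no activations need to be retained anyway.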