acon96 commited on
Commit
78c707e
·
1 Parent(s): d318676

Gradient Checkpointing for HF Trainer

Browse files

Wire up existing checkpoint logic to work with transformers Trainer

Files changed (1) hide show
  1. modeling_phi.py +13 -7
modeling_phi.py CHANGED
@@ -605,9 +605,9 @@ class MHA(nn.Module):
605
  # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
606
  qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
607
 
608
- if self.checkpointing:
609
  attn_output = torch.utils.checkpoint.checkpoint(
610
- self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
611
  )
612
  else:
613
  attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
@@ -615,8 +615,8 @@ class MHA(nn.Module):
615
  # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
616
  return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
617
 
618
- if self.checkpointing:
619
- return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
620
 
621
  return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
622
 
@@ -664,7 +664,7 @@ class MHA(nn.Module):
664
 
665
  q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
666
 
667
- if self.checkpointing:
668
  attn_output = torch.utils.checkpoint.checkpoint(
669
  self.inner_cross_attn,
670
  q,
@@ -674,6 +674,7 @@ class MHA(nn.Module):
674
  max_seqlen=max_seqlen_q,
675
  cu_seqlens_k=cu_seqlens_k,
676
  max_seqlen_k=max_seqlen_k,
 
677
  )
678
  else:
679
  attn_output = self.inner_cross_attn(
@@ -692,13 +693,14 @@ class MHA(nn.Module):
692
  else attn_output
693
  )
694
 
695
- if self.checkpointing:
696
  return torch.utils.checkpoint.checkpoint(
697
  self.inner_cross_attn,
698
  q,
699
  kv,
700
  key_padding_mask=key_padding_mask,
701
  causal=causal,
 
702
  )
703
 
704
  return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
@@ -835,7 +837,7 @@ class PhiPreTrainedModel(PreTrainedModel):
835
 
836
  config_class = PhiConfig
837
  base_model_prefix = "transformer"
838
- supports_gradient_checkpointing = False
839
  _no_split_modules = ["ParallelBlock"]
840
 
841
  def __init__(self, *inputs, **kwargs) -> None:
@@ -855,6 +857,10 @@ class PhiPreTrainedModel(PreTrainedModel):
855
  module.bias.data.zero_()
856
  module.weight.data.fill_(1.0)
857
 
 
 
 
 
858
  def prepare_inputs_for_generation(
859
  self,
860
  input_ids: torch.LongTensor,
 
605
  # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
606
  qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
607
 
608
+ if self.checkpointing and self.training:
609
  attn_output = torch.utils.checkpoint.checkpoint(
610
+ self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, use_reentrant=False
611
  )
612
  else:
613
  attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
 
615
  # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
616
  return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
617
 
618
+ if self.checkpointing and self.training:
619
+ return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask, use_reentrant=False)
620
 
621
  return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
622
 
 
664
 
665
  q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
666
 
667
+ if self.checkpointing and self.training:
668
  attn_output = torch.utils.checkpoint.checkpoint(
669
  self.inner_cross_attn,
670
  q,
 
674
  max_seqlen=max_seqlen_q,
675
  cu_seqlens_k=cu_seqlens_k,
676
  max_seqlen_k=max_seqlen_k,
677
+ use_reentrant=False
678
  )
679
  else:
680
  attn_output = self.inner_cross_attn(
 
693
  else attn_output
694
  )
695
 
696
+ if self.checkpointing and self.training:
697
  return torch.utils.checkpoint.checkpoint(
698
  self.inner_cross_attn,
699
  q,
700
  kv,
701
  key_padding_mask=key_padding_mask,
702
  causal=causal,
703
+ use_reentrant=False
704
  )
705
 
706
  return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
 
837
 
838
  config_class = PhiConfig
839
  base_model_prefix = "transformer"
840
+ supports_gradient_checkpointing = True
841
  _no_split_modules = ["ParallelBlock"]
842
 
843
  def __init__(self, *inputs, **kwargs) -> None:
 
857
  module.bias.data.zero_()
858
  module.weight.data.fill_(1.0)
859
 
860
+ def _set_gradient_checkpointing(self, module, value=False):
861
+ if isinstance(module, MHA):
862
+ module.checkpointing = value
863
+
864
  def prepare_inputs_for_generation(
865
  self,
866
  input_ids: torch.LongTensor,