Gradient Checkpointing for HF Trainer
Browse files
Wire up existing checkpoint logic to work with transformers Trainer
- modeling_phi.py +13 -7
modeling_phi.py
CHANGED
@@ -605,9 +605,9 @@ class MHA(nn.Module):
|
|
605 |
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
606 |
qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
|
607 |
|
608 |
-
if self.checkpointing:
|
609 |
attn_output = torch.utils.checkpoint.checkpoint(
|
610 |
-
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
611 |
)
|
612 |
else:
|
613 |
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
@@ -615,8 +615,8 @@ class MHA(nn.Module):
|
|
615 |
# If `key_padding_mask` is supplied, we need to pad the output back to the original shape
|
616 |
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
617 |
|
618 |
-
if self.checkpointing:
|
619 |
-
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
|
620 |
|
621 |
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
622 |
|
@@ -664,7 +664,7 @@ class MHA(nn.Module):
|
|
664 |
|
665 |
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
|
666 |
|
667 |
-
if self.checkpointing:
|
668 |
attn_output = torch.utils.checkpoint.checkpoint(
|
669 |
self.inner_cross_attn,
|
670 |
q,
|
@@ -674,6 +674,7 @@ class MHA(nn.Module):
|
|
674 |
max_seqlen=max_seqlen_q,
|
675 |
cu_seqlens_k=cu_seqlens_k,
|
676 |
max_seqlen_k=max_seqlen_k,
|
|
|
677 |
)
|
678 |
else:
|
679 |
attn_output = self.inner_cross_attn(
|
@@ -692,13 +693,14 @@ class MHA(nn.Module):
|
|
692 |
else attn_output
|
693 |
)
|
694 |
|
695 |
-
if self.checkpointing:
|
696 |
return torch.utils.checkpoint.checkpoint(
|
697 |
self.inner_cross_attn,
|
698 |
q,
|
699 |
kv,
|
700 |
key_padding_mask=key_padding_mask,
|
701 |
causal=causal,
|
|
|
702 |
)
|
703 |
|
704 |
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
|
@@ -835,7 +837,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
835 |
|
836 |
config_class = PhiConfig
|
837 |
base_model_prefix = "transformer"
|
838 |
-
supports_gradient_checkpointing = False
|
839 |
_no_split_modules = ["ParallelBlock"]
|
840 |
|
841 |
def __init__(self, *inputs, **kwargs) -> None:
|
@@ -855,6 +857,10 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
855 |
module.bias.data.zero_()
|
856 |
module.weight.data.fill_(1.0)
|
857 |
|
|
|
|
|
|
|
|
|
858 |
def prepare_inputs_for_generation(
|
859 |
self,
|
860 |
input_ids: torch.LongTensor,
|
|
|
605 |
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
606 |
qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
|
607 |
|
608 |
+
if self.checkpointing and self.training:
|
609 |
attn_output = torch.utils.checkpoint.checkpoint(
|
610 |
+
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, use_reentrant=False
|
611 |
)
|
612 |
else:
|
613 |
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
|
|
615 |
# If `key_padding_mask` is supplied, we need to pad the output back to the original shape
|
616 |
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
617 |
|
618 |
+
if self.checkpointing and self.training:
|
619 |
+
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask, use_reentrant=False)
|
620 |
|
621 |
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
622 |
|
|
|
664 |
|
665 |
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
|
666 |
|
667 |
+
if self.checkpointing and self.training:
|
668 |
attn_output = torch.utils.checkpoint.checkpoint(
|
669 |
self.inner_cross_attn,
|
670 |
q,
|
|
|
674 |
max_seqlen=max_seqlen_q,
|
675 |
cu_seqlens_k=cu_seqlens_k,
|
676 |
max_seqlen_k=max_seqlen_k,
|
677 |
+
use_reentrant=False
|
678 |
)
|
679 |
else:
|
680 |
attn_output = self.inner_cross_attn(
|
|
|
693 |
else attn_output
|
694 |
)
|
695 |
|
696 |
+
if self.checkpointing and self.training:
|
697 |
return torch.utils.checkpoint.checkpoint(
|
698 |
self.inner_cross_attn,
|
699 |
q,
|
700 |
kv,
|
701 |
key_padding_mask=key_padding_mask,
|
702 |
causal=causal,
|
703 |
+
use_reentrant=False
|
704 |
)
|
705 |
|
706 |
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
|
|
|
837 |
|
838 |
config_class = PhiConfig
|
839 |
base_model_prefix = "transformer"
|
840 |
+
supports_gradient_checkpointing = True
|
841 |
_no_split_modules = ["ParallelBlock"]
|
842 |
|
843 |
def __init__(self, *inputs, **kwargs) -> None:
|
|
|
857 |
module.bias.data.zero_()
|
858 |
module.weight.data.fill_(1.0)
|
859 |
|
860 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
861 |
+
if isinstance(module, MHA):
|
862 |
+
module.checkpointing = value
|
863 |
+
|
864 |
def prepare_inputs_for_generation(
|
865 |
self,
|
866 |
input_ids: torch.LongTensor,
|