enable activation checkpointing
#10
by smangrul · opened
- modeling_phi.py +1 -2
modeling_phi.py CHANGED
@@ -525,7 +525,6 @@ class MHA(nn.Module):
         softmax_scale: Optional[float] = None,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
-        checkpointing: bool = False,
     ) -> None:
         super().__init__()

@@ -585,7 +584,7 @@ class MHA(nn.Module):
         self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
         self.layer_idx = layer_idx
         self.return_residual = return_residual
-        self.checkpointing = checkpointing
+        self.checkpointing = getattr(config, "checkpointing", False)

     def _forward_self_attn(
         self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
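With this change, activation checkpointing is no longer a constructor argument of `MHA`; the flag is read from the model config, so it can be toggled without editing the modeling code. Below is a minimal sketch of how a user might turn it on, assuming this `modeling_phi.py` is loaded as remote code for a Phi checkpoint; the model id is illustrative and not stated anywhere in this PR.

```python
# Sketch: enable activation checkpointing through the config.
# Assumes the checkpoint ships this modeling_phi.py as remote code;
# "microsoft/phi-1_5" is an illustrative model id, not taken from the PR.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
config.checkpointing = True  # read in MHA via getattr(config, "checkpointing", False)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5",
    config=config,
    trust_remote_code=True,
)
```

Inside the attention module, a `self.checkpointing` flag like this is typically used to route the attention call through `torch.utils.checkpoint.checkpoint` during training, trading extra recomputation in the backward pass for lower activation memory. The `False` default in `getattr` keeps configs that lack the attribute working unchanged.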