| | |
| | |
| |
|
| | |
| |
|
| | from functools import partial |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.fx |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| |
|
| | from .mha import MHA |
| | from .mlp import Mlp |
| |
|
| | try: |
| | from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm |
| | except ImportError: |
| | layer_norm_fn, RMSNorm = None, None |
| |
|
| |
|
def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Stochastic Depth, from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_, used to randomly drop the residual
    branches of residual architectures.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions, with the first
            one being its batch, i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if ``True``. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor. Surviving entries are scaled by
        ``1 / (1 - p)``. ``input`` itself is returned when ``training`` is false or
        ``p == 0``.

    Raises:
        ValueError: if ``p`` lies outside ``[0, 1]`` or ``mode`` is not one of the
            two supported strings.
    """
    if p > 1.0 or p < 0.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ("batch", "row"):
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")

    # Nothing to drop: hand the very same tensor back.
    if p == 0.0 or not training:
        return input

    keep_prob = 1.0 - p

    # One Bernoulli draw for the whole batch, or one per row; every other axis
    # has size 1 so the mask broadcasts over the rest of the input.
    mask_shape = [1] * input.ndim
    if mode == "row":
        mask_shape[0] = input.shape[0]

    mask = torch.empty(
        mask_shape, dtype=input.dtype, device=input.device
    ).bernoulli_(keep_prob)

    # Rescale the kept entries; skipped when keep_prob == 0 to avoid dividing by zero
    # (the mask is all zeros in that case anyway).
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return input * mask


# Register the function with torch.fx by name (see the torch.fx.wrap documentation
# for how wrapped functions are treated during symbolic tracing).
torch.fx.wrap("stochastic_depth")
| |
|
| |
|
class StochasticDepth(nn.Module):
    """
    Module form of :func:`stochastic_depth`.

    Stores the drop probability ``p`` and the ``mode`` string and applies the
    function using the module's own ``training`` flag.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        self.mode = mode
        self.p = p

    def forward(self, input: Tensor) -> Tensor:
        # self.training is toggled by .train() / .eval() on the module.
        return stochastic_depth(
            input, p=self.p, mode=self.mode, training=self.training
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
| |
|
| | |
class Block(nn.Module):
    """Transformer block made of a token mixer and an (optional) MLP.

    Each sub-layer is surrounded by residual dropout, stochastic depth (drop-path),
    a residual add and a norm layer. Pre-norm and post-norm layouts are both
    supported, and the dropout/add/norm sequence can optionally be delegated to
    ``layer_norm_fn`` when ``fused_dropout_add_ln=True``.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        Args:
            dim: feature dimension handed to ``mixer_cls``, ``mlp_cls`` and ``norm_cls``.
            mixer_cls: factory called as ``mixer_cls(dim)``; defaults to ``MHA`` with
                ``dim // 64`` heads.
            mlp_cls: factory called as ``mlp_cls(dim)``; defaults to ``Mlp`` with
                ``4 * dim`` hidden features. If it produces an ``nn.Identity`` the second
                dropout / drop-path / norm are not created and the MLP stage is skipped.
            norm_cls: factory called as ``norm_cls(dim)`` for each norm layer.
            dropout_cls: factory called with the residual dropout probability.
            prenorm: selects the pre-norm (True) or post-norm (False) layout.
            resid_dropout1, resid_dropout2: dropout probabilities before the first /
                second residual add.
            drop_path1, drop_path2: stochastic-depth probabilities (row mode) before the
                first / second residual add.
            fused_dropout_add_ln: use ``layer_norm_fn`` for dropout + add + norm.
            return_residual: whether each of the sub-layers (mixer and mlp) will return the
                residual. This is for performance reason: for post-norm architecture,
                returning the input allows us to fuse the backward of nn.Linear with the
                residual connection.
            residual_in_fp32: keep the residual stream in float32 (prenorm only).
            sequence_parallel: tag norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        # NOTE: construction order is kept stable; it determines parameter
        # registration order and the order in which initializers consume RNG.
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        if sequence_parallel:
            self._flag_norm_params("_sequence_parallel")
        if mark_shared_params:
            self._flag_norm_params("_shared_params")

    def _flag_norm_params(self, flag):
        """Set ``param.<flag> = True`` on every parameter of norm1 (and norm2 if present)."""
        for p in self.norm1.parameters():
            setattr(p, flag, True)
        if hasattr(self, "norm2"):
            for p in self.norm2.parameters():
                setattr(p, flag, True)

    def _drop_path_rowscale(self, drop_path, x):
        """Return the per-row drop-path scale for the fused path, or None if inactive."""
        if drop_path.p == 0 or not self.training:
            return None
        # Running drop-path over a tensor of ones yields the row-wise scale factors
        # (0 or 1 / survival_rate) that are handed to layer_norm_fn as ``rowscale``.
        return drop_path(
            torch.ones(x.shape[:-1], device=x.device, dtype=x.dtype)
        )

    def _dropout_add_norm(self, x, residual, norm, dropout, drop_path, prenorm):
        """Dropout + drop-path on ``x``, add ``residual``, then apply ``norm``.

        Returns ``(normed, new_residual)`` when ``prenorm`` is True and just ``normed``
        otherwise. Uses ``layer_norm_fn`` when ``self.fused_dropout_add_ln`` is set.
        """
        if not self.fused_dropout_add_ln:
            dropped = drop_path(dropout(x))
            summed = (dropped + residual) if residual is not None else dropped
            normed = norm(summed.to(dtype=norm.weight.dtype))
            if not prenorm:
                return normed
            if self.residual_in_fp32:
                summed = summed.to(torch.float32)
            return normed, summed

        # residual_in_fp32 is only forwarded in the prenorm layout (the post-norm
        # layout never enables it, see the assertion in __init__).
        prenorm_kwargs = (
            {"residual_in_fp32": self.residual_in_fp32} if prenorm else {}
        )
        return layer_norm_fn(
            x,
            norm.weight,
            norm.bias,
            residual=residual,
            eps=norm.eps,
            dropout_p=dropout.p if self.training else 0.0,
            rowscale=self._drop_path_rowscale(drop_path, x),
            prenorm=prenorm,
            is_rms_norm=isinstance(norm, RMSNorm),
            **prenorm_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the mixer."""
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer. Only read in the prenorm layout.
            mixer_kwargs: optional extra keyword arguments forwarded to the mixer. The
                mapping is never modified by this method.

        Returns:
            ``(hidden_states, residual)`` if ``prenorm`` else ``hidden_states``.
        """
        if self.prenorm:
            hidden_states, residual = self._dropout_add_norm(
                hidden_states,
                residual,
                self.norm1,
                self.dropout1,
                self.drop_path1,
                prenorm=True,
            )
            # Work on a copy: injecting "mixer_subset" into the caller's dict would
            # leak the subset into every later call that reuses the same dict.
            mixer_kwargs = {} if mixer_kwargs is None else dict(mixer_kwargs)
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Keep the residual aligned with the subset the mixer produced.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                hidden_states, residual = self._dropout_add_norm(
                    hidden_states,
                    residual,
                    self.norm2,
                    self.dropout2,
                    self.drop_path2,
                    prenorm=True,
                )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual

        assert residual is None
        mixer_out = self.mixer(
            hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
        )
        if self.return_residual:
            # The mixer hands back (output, input) so the input can serve as residual.
            mixer_out, hidden_states = mixer_out
        hidden_states = self._dropout_add_norm(
            mixer_out,
            hidden_states,
            self.norm1,
            self.dropout1,
            self.drop_path1,
            prenorm=False,
        )
        if not isinstance(self.mlp, nn.Identity):
            mlp_out = self.mlp(hidden_states)
            if self.return_residual:
                mlp_out, hidden_states = mlp_out
            hidden_states = self._dropout_add_norm(
                mlp_out,
                hidden_states,
                self.norm2,
                self.dropout2,
                self.drop_path2,
                prenorm=False,
            )
        return hidden_states
| |
|
| |
|
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (outputs of the MHA and of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        With ``tied_norm=True`` a single norm layer feeds both branches and no
        ``norm2`` is created.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.tied_norm = tied_norm

        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)

        # Sub-modules are built in a fixed order so that parameter registration
        # (and any random initialization) stays the same.
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # Tag the norm parameters with the requested marker attributes.
        norm_layers = [self.norm1]
        if hasattr(self, "norm2"):
            norm_layers.append(self.norm2)
        requested_flags = (
            (sequence_parallel, "_sequence_parallel"),
            (mark_shared_params, "_shared_params"),
        )
        for enabled, flag_name in requested_flags:
            if not enabled:
                continue
            for norm_layer in norm_layers:
                for param in norm_layer.parameters():
                    setattr(param, flag_name, True)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Forward the inference-cache allocation request to the mixer."""
        cache = self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
        return cache

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the running residual stream, or None for the very first block.
            mixer_kwargs: optional extra keyword arguments forwarded to the mixer.

        Returns:
            ``(mixer_output, mlp_output, residual)``.
        """
        if self.fused_dropout_add_ln:
            # With a tied norm there is no second weight/bias pair to hand over.
            if self.tied_norm:
                weight2, bias2 = None, None
            else:
                weight2, bias2 = self.norm2.weight, self.norm2.bias
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                # A second normalized output is expected between the first and the residual.
                (hidden_states2,) = rest
        else:
            dropped1 = self.dropout1(hidden_states1)
            if hidden_states2 is None:
                if residual is not None:
                    residual = residual + dropped1
                else:
                    residual = dropped1
            else:
                dropped2 = self.dropout2(hidden_states2)
                if residual is not None:
                    residual = residual + dropped1 + dropped2
                else:
                    residual = dropped1 + dropped2

            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2 = self.norm2(
                    residual.to(dtype=self.norm2.weight.dtype)
                )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        extra_mixer_args = {} if mixer_kwargs is None else mixer_kwargs
        hidden_states1 = self.mixer(hidden_states1, **extra_mixer_args)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
| |
|